Skip to main content

openauth_plugins/device_authorization/
store.rs

1use openauth_core::crypto::random::generate_random_string;
2use openauth_core::db::{Create, DbAdapter, DbRecord, DbValue, Delete, FindOne, Update, Where};
3use openauth_core::error::OpenAuthError;
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6
7use super::schema::DEVICE_CODE_MODEL;
8
/// Every column of the device-code model that this store reads back.
///
/// Passed to `select` on both the create and the lookup queries so that the
/// returned rows always contain what `record_from_db` expects.
const DEVICE_CODE_FIELDS: [&str; 12] = [
    // Identity and lookup keys.
    "id",
    "deviceCode",
    "userCode",
    // Authorization state.
    "userId",
    "expiresAt",
    "status",
    "lastPolledAt",
    "pollingInterval",
    // Requesting client.
    "clientId",
    "scope",
    // Bookkeeping timestamps.
    "createdAt",
    "updatedAt",
];
/// Length argument handed to `generate_random_string` when minting a row id.
const DEFAULT_ID_LENGTH: usize = 32;
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "lowercase")]
27pub enum DeviceAuthorizationStatus {
28    Pending,
29    Approved,
30    Denied,
31}
32
33impl DeviceAuthorizationStatus {
34    pub fn as_str(self) -> &'static str {
35        match self {
36            Self::Pending => "pending",
37            Self::Approved => "approved",
38            Self::Denied => "denied",
39        }
40    }
41}
42
43impl TryFrom<&str> for DeviceAuthorizationStatus {
44    type Error = OpenAuthError;
45
46    fn try_from(value: &str) -> Result<Self, Self::Error> {
47        match value {
48            "pending" => Ok(Self::Pending),
49            "approved" => Ok(Self::Approved),
50            "denied" => Ok(Self::Denied),
51            _ => Err(OpenAuthError::Adapter(format!(
52                "device code status `{value}` is invalid"
53            ))),
54        }
55    }
56}
57
/// A decoded device-code row, as returned by every read or write method of
/// `DeviceCodeStore`.
///
/// Each field mirrors the camelCase database column noted on it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DeviceCodeRecord {
    /// `id` column; generated by the store on create.
    pub id: String,
    /// `deviceCode` column; the key used by `find_by_device_code`.
    pub device_code: String,
    /// `userCode` column; the key used by `find_by_user_code`.
    pub user_code: String,
    /// `userId` column; null on create, set by `approve` or `deny`.
    pub user_id: Option<String>,
    /// `expiresAt` column; supplied by the caller on create.
    pub expires_at: OffsetDateTime,
    /// `status` column; `Pending` on create.
    pub status: DeviceAuthorizationStatus,
    /// `lastPolledAt` column; null on create, refreshed by `mark_polled`.
    pub last_polled_at: Option<OffsetDateTime>,
    /// `pollingInterval` column; may be null when decoding.
    // NOTE(review): unit is not visible in this file (presumably seconds) — confirm.
    pub polling_interval: Option<i64>,
    /// `clientId` column; always written on create but accepted as null when decoding.
    pub client_id: Option<String>,
    /// `scope` column; null when no scope was requested.
    pub scope: Option<String>,
    /// `createdAt` column.
    pub created_at: OffsetDateTime,
    /// `updatedAt` column; rewritten on every update made through the store.
    pub updated_at: OffsetDateTime,
}
73
/// Caller-supplied values for `DeviceCodeStore::create`.
///
/// The store fills in the rest of the row itself: id, `pending` status,
/// null `userId`/`lastPolledAt`, and the creation timestamps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CreateDeviceCodeInput {
    /// Written to the `deviceCode` column.
    pub device_code: String,
    /// Written to the `userCode` column.
    pub user_code: String,
    /// Written to the `expiresAt` column.
    pub expires_at: OffsetDateTime,
    // NOTE(review): unit is not visible in this file (presumably seconds) — confirm.
    /// Written to the `pollingInterval` column.
    pub polling_interval: i64,
    /// Written to the `clientId` column.
    pub client_id: String,
    /// Written to the `scope` column; `None` is stored as null.
    pub scope: Option<String>,
}
83
/// Data-access wrapper for the device-code model.
///
/// Holds only a borrowed adapter, so it is `Copy` and lives no longer than
/// the adapter reference `'a`.
#[derive(Clone, Copy)]
pub struct DeviceCodeStore<'a> {
    /// Adapter through which every query of this store is issued.
    adapter: &'a dyn DbAdapter,
}
88
89impl<'a> DeviceCodeStore<'a> {
90    pub fn new(adapter: &'a dyn DbAdapter) -> Self {
91        Self { adapter }
92    }
93
94    pub async fn create(
95        &self,
96        input: CreateDeviceCodeInput,
97    ) -> Result<DeviceCodeRecord, OpenAuthError> {
98        let now = OffsetDateTime::now_utc();
99        let record = self
100            .adapter
101            .create(
102                Create::new(DEVICE_CODE_MODEL)
103                    .data(
104                        "id",
105                        DbValue::String(generate_random_string(DEFAULT_ID_LENGTH)),
106                    )
107                    .data("deviceCode", DbValue::String(input.device_code))
108                    .data("userCode", DbValue::String(input.user_code))
109                    .data("userId", DbValue::Null)
110                    .data("expiresAt", DbValue::Timestamp(input.expires_at))
111                    .data(
112                        "status",
113                        DbValue::String(DeviceAuthorizationStatus::Pending.as_str().to_owned()),
114                    )
115                    .data("lastPolledAt", DbValue::Null)
116                    .data("pollingInterval", DbValue::Number(input.polling_interval))
117                    .data("clientId", DbValue::String(input.client_id))
118                    .data("scope", optional_string(input.scope))
119                    .data("createdAt", DbValue::Timestamp(now))
120                    .data("updatedAt", DbValue::Timestamp(now))
121                    .select(DEVICE_CODE_FIELDS)
122                    .force_allow_id(),
123            )
124            .await?;
125        record_from_db(record)
126    }
127
128    pub async fn find_by_device_code(
129        &self,
130        device_code: &str,
131    ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
132        self.find_one(Where::new(
133            "deviceCode",
134            DbValue::String(device_code.to_owned()),
135        ))
136        .await
137    }
138
139    pub async fn find_by_user_code(
140        &self,
141        user_code: &str,
142    ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
143        self.find_one(Where::new(
144            "userCode",
145            DbValue::String(user_code.to_owned()),
146        ))
147        .await
148    }
149
150    pub async fn mark_polled(&self, id: &str) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
151        self.update(
152            id,
153            DbRecord::from([(
154                "lastPolledAt".to_owned(),
155                DbValue::Timestamp(OffsetDateTime::now_utc()),
156            )]),
157        )
158        .await
159    }
160
161    pub async fn approve(
162        &self,
163        id: &str,
164        user_id: &str,
165    ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
166        self.update(
167            id,
168            DbRecord::from([
169                (
170                    "status".to_owned(),
171                    DbValue::String(DeviceAuthorizationStatus::Approved.as_str().to_owned()),
172                ),
173                ("userId".to_owned(), DbValue::String(user_id.to_owned())),
174            ]),
175        )
176        .await
177    }
178
179    pub async fn deny(
180        &self,
181        id: &str,
182        user_id: &str,
183    ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
184        self.update(
185            id,
186            DbRecord::from([
187                (
188                    "status".to_owned(),
189                    DbValue::String(DeviceAuthorizationStatus::Denied.as_str().to_owned()),
190                ),
191                ("userId".to_owned(), DbValue::String(user_id.to_owned())),
192            ]),
193        )
194        .await
195    }
196
197    pub async fn delete(&self, id: &str) -> Result<(), OpenAuthError> {
198        self.adapter
199            .delete(Delete::new(DEVICE_CODE_MODEL).where_clause(id_where(id)))
200            .await
201    }
202
203    async fn find_one(
204        &self,
205        where_clause: Where,
206    ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
207        self.adapter
208            .find_one(
209                FindOne::new(DEVICE_CODE_MODEL)
210                    .where_clause(where_clause)
211                    .select(DEVICE_CODE_FIELDS),
212            )
213            .await?
214            .map(record_from_db)
215            .transpose()
216    }
217
218    async fn update(
219        &self,
220        id: &str,
221        data: DbRecord,
222    ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
223        let mut query = Update::new(DEVICE_CODE_MODEL)
224            .where_clause(id_where(id))
225            .data("updatedAt", DbValue::Timestamp(OffsetDateTime::now_utc()));
226        for (field, value) in data {
227            query = query.data(field, value);
228        }
229
230        self.adapter
231            .update(query)
232            .await?
233            .map(record_from_db)
234            .transpose()
235    }
236}
237
238fn id_where(id: &str) -> Where {
239    Where::new("id", DbValue::String(id.to_owned()))
240}
241
242fn optional_string(value: Option<String>) -> DbValue {
243    value.map(DbValue::String).unwrap_or(DbValue::Null)
244}
245
246fn record_from_db(record: DbRecord) -> Result<DeviceCodeRecord, OpenAuthError> {
247    Ok(DeviceCodeRecord {
248        id: required_string(&record, "id")?.to_owned(),
249        device_code: required_string(&record, "deviceCode")?.to_owned(),
250        user_code: required_string(&record, "userCode")?.to_owned(),
251        user_id: optional_string_field(&record, "userId")?,
252        expires_at: required_timestamp(&record, "expiresAt")?,
253        status: DeviceAuthorizationStatus::try_from(required_string(&record, "status")?)?,
254        last_polled_at: optional_timestamp(&record, "lastPolledAt")?,
255        polling_interval: optional_number(&record, "pollingInterval")?,
256        client_id: optional_string_field(&record, "clientId")?,
257        scope: optional_string_field(&record, "scope")?,
258        created_at: required_timestamp(&record, "createdAt")?,
259        updated_at: required_timestamp(&record, "updatedAt")?,
260    })
261}
262
263fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, OpenAuthError> {
264    match record.get(field) {
265        Some(DbValue::String(value)) => Ok(value),
266        Some(_) => Err(invalid_field(field, "string")),
267        None => Err(missing_field(field)),
268    }
269}
270
271fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, OpenAuthError> {
272    match record.get(field) {
273        Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
274        Some(DbValue::Null) | None => Ok(None),
275        Some(_) => Err(invalid_field(field, "string or null")),
276    }
277}
278
279fn required_timestamp(record: &DbRecord, field: &str) -> Result<OffsetDateTime, OpenAuthError> {
280    match record.get(field) {
281        Some(DbValue::Timestamp(value)) => Ok(*value),
282        Some(_) => Err(invalid_field(field, "timestamp")),
283        None => Err(missing_field(field)),
284    }
285}
286
287fn optional_timestamp(
288    record: &DbRecord,
289    field: &str,
290) -> Result<Option<OffsetDateTime>, OpenAuthError> {
291    match record.get(field) {
292        Some(DbValue::Timestamp(value)) => Ok(Some(*value)),
293        Some(DbValue::Null) | None => Ok(None),
294        Some(_) => Err(invalid_field(field, "timestamp or null")),
295    }
296}
297
298fn optional_number(record: &DbRecord, field: &str) -> Result<Option<i64>, OpenAuthError> {
299    match record.get(field) {
300        Some(DbValue::Number(value)) => Ok(Some(*value)),
301        Some(DbValue::Null) | None => Ok(None),
302        Some(_) => Err(invalid_field(field, "number or null")),
303    }
304}
305
306fn missing_field(field: &str) -> OpenAuthError {
307    OpenAuthError::Adapter(format!("device code record is missing `{field}`"))
308}
309
310fn invalid_field(field: &str, expected: &str) -> OpenAuthError {
311    OpenAuthError::Adapter(format!(
312        "device code record field `{field}` must be {expected}"
313    ))
314}