1use openauth_core::crypto::random::generate_random_string;
2use openauth_core::db::{Create, DbAdapter, DbRecord, DbValue, Delete, FindOne, Update, Where};
3use openauth_core::error::OpenAuthError;
4use serde::{Deserialize, Serialize};
5use time::OffsetDateTime;
6
7use super::schema::DEVICE_CODE_MODEL;
8
/// Column names requested through `.select(...)` by the create and find-one
/// queries in this file; one entry per field of `DeviceCodeRecord`.
const DEVICE_CODE_FIELDS: [&str; 12] = [
    // Identity and the two codes.
    "id", "deviceCode", "userCode", "userId",
    // Lifecycle and polling state.
    "expiresAt", "status", "lastPolledAt", "pollingInterval",
    // Client request details.
    "clientId", "scope",
    // Bookkeeping timestamps.
    "createdAt", "updatedAt",
];

/// Length passed to `generate_random_string` when minting a new record id.
const DEFAULT_ID_LENGTH: usize = 32;
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "lowercase")]
27pub enum DeviceAuthorizationStatus {
28 Pending,
29 Approved,
30 Denied,
31}
32
33impl DeviceAuthorizationStatus {
34 pub fn as_str(self) -> &'static str {
35 match self {
36 Self::Pending => "pending",
37 Self::Approved => "approved",
38 Self::Denied => "denied",
39 }
40 }
41}
42
43impl TryFrom<&str> for DeviceAuthorizationStatus {
44 type Error = OpenAuthError;
45
46 fn try_from(value: &str) -> Result<Self, Self::Error> {
47 match value {
48 "pending" => Ok(Self::Pending),
49 "approved" => Ok(Self::Approved),
50 "denied" => Ok(Self::Denied),
51 _ => Err(OpenAuthError::Adapter(format!(
52 "device code status `{value}` is invalid"
53 ))),
54 }
55 }
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
59pub struct DeviceCodeRecord {
60 pub id: String,
61 pub device_code: String,
62 pub user_code: String,
63 pub user_id: Option<String>,
64 pub expires_at: OffsetDateTime,
65 pub status: DeviceAuthorizationStatus,
66 pub last_polled_at: Option<OffsetDateTime>,
67 pub polling_interval: Option<i64>,
68 pub client_id: Option<String>,
69 pub scope: Option<String>,
70 pub created_at: OffsetDateTime,
71 pub updated_at: OffsetDateTime,
72}
73
74#[derive(Debug, Clone, PartialEq, Eq)]
75pub struct CreateDeviceCodeInput {
76 pub device_code: String,
77 pub user_code: String,
78 pub expires_at: OffsetDateTime,
79 pub polling_interval: i64,
80 pub client_id: String,
81 pub scope: Option<String>,
82}
83
84#[derive(Clone, Copy)]
85pub struct DeviceCodeStore<'a> {
86 adapter: &'a dyn DbAdapter,
87}
88
89impl<'a> DeviceCodeStore<'a> {
90 pub fn new(adapter: &'a dyn DbAdapter) -> Self {
91 Self { adapter }
92 }
93
94 pub async fn create(
95 &self,
96 input: CreateDeviceCodeInput,
97 ) -> Result<DeviceCodeRecord, OpenAuthError> {
98 let now = OffsetDateTime::now_utc();
99 let record = self
100 .adapter
101 .create(
102 Create::new(DEVICE_CODE_MODEL)
103 .data(
104 "id",
105 DbValue::String(generate_random_string(DEFAULT_ID_LENGTH)),
106 )
107 .data("deviceCode", DbValue::String(input.device_code))
108 .data("userCode", DbValue::String(input.user_code))
109 .data("userId", DbValue::Null)
110 .data("expiresAt", DbValue::Timestamp(input.expires_at))
111 .data(
112 "status",
113 DbValue::String(DeviceAuthorizationStatus::Pending.as_str().to_owned()),
114 )
115 .data("lastPolledAt", DbValue::Null)
116 .data("pollingInterval", DbValue::Number(input.polling_interval))
117 .data("clientId", DbValue::String(input.client_id))
118 .data("scope", optional_string(input.scope))
119 .data("createdAt", DbValue::Timestamp(now))
120 .data("updatedAt", DbValue::Timestamp(now))
121 .select(DEVICE_CODE_FIELDS)
122 .force_allow_id(),
123 )
124 .await?;
125 record_from_db(record)
126 }
127
128 pub async fn find_by_device_code(
129 &self,
130 device_code: &str,
131 ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
132 self.find_one(Where::new(
133 "deviceCode",
134 DbValue::String(device_code.to_owned()),
135 ))
136 .await
137 }
138
139 pub async fn find_by_user_code(
140 &self,
141 user_code: &str,
142 ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
143 self.find_one(Where::new(
144 "userCode",
145 DbValue::String(user_code.to_owned()),
146 ))
147 .await
148 }
149
150 pub async fn mark_polled(&self, id: &str) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
151 self.update(
152 id,
153 DbRecord::from([(
154 "lastPolledAt".to_owned(),
155 DbValue::Timestamp(OffsetDateTime::now_utc()),
156 )]),
157 )
158 .await
159 }
160
161 pub async fn approve(
162 &self,
163 id: &str,
164 user_id: &str,
165 ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
166 self.update(
167 id,
168 DbRecord::from([
169 (
170 "status".to_owned(),
171 DbValue::String(DeviceAuthorizationStatus::Approved.as_str().to_owned()),
172 ),
173 ("userId".to_owned(), DbValue::String(user_id.to_owned())),
174 ]),
175 )
176 .await
177 }
178
179 pub async fn deny(
180 &self,
181 id: &str,
182 user_id: &str,
183 ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
184 self.update(
185 id,
186 DbRecord::from([
187 (
188 "status".to_owned(),
189 DbValue::String(DeviceAuthorizationStatus::Denied.as_str().to_owned()),
190 ),
191 ("userId".to_owned(), DbValue::String(user_id.to_owned())),
192 ]),
193 )
194 .await
195 }
196
197 pub async fn delete(&self, id: &str) -> Result<(), OpenAuthError> {
198 self.adapter
199 .delete(Delete::new(DEVICE_CODE_MODEL).where_clause(id_where(id)))
200 .await
201 }
202
203 async fn find_one(
204 &self,
205 where_clause: Where,
206 ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
207 self.adapter
208 .find_one(
209 FindOne::new(DEVICE_CODE_MODEL)
210 .where_clause(where_clause)
211 .select(DEVICE_CODE_FIELDS),
212 )
213 .await?
214 .map(record_from_db)
215 .transpose()
216 }
217
218 async fn update(
219 &self,
220 id: &str,
221 data: DbRecord,
222 ) -> Result<Option<DeviceCodeRecord>, OpenAuthError> {
223 let mut query = Update::new(DEVICE_CODE_MODEL)
224 .where_clause(id_where(id))
225 .data("updatedAt", DbValue::Timestamp(OffsetDateTime::now_utc()));
226 for (field, value) in data {
227 query = query.data(field, value);
228 }
229
230 self.adapter
231 .update(query)
232 .await?
233 .map(record_from_db)
234 .transpose()
235 }
236}
237
238fn id_where(id: &str) -> Where {
239 Where::new("id", DbValue::String(id.to_owned()))
240}
241
242fn optional_string(value: Option<String>) -> DbValue {
243 value.map(DbValue::String).unwrap_or(DbValue::Null)
244}
245
246fn record_from_db(record: DbRecord) -> Result<DeviceCodeRecord, OpenAuthError> {
247 Ok(DeviceCodeRecord {
248 id: required_string(&record, "id")?.to_owned(),
249 device_code: required_string(&record, "deviceCode")?.to_owned(),
250 user_code: required_string(&record, "userCode")?.to_owned(),
251 user_id: optional_string_field(&record, "userId")?,
252 expires_at: required_timestamp(&record, "expiresAt")?,
253 status: DeviceAuthorizationStatus::try_from(required_string(&record, "status")?)?,
254 last_polled_at: optional_timestamp(&record, "lastPolledAt")?,
255 polling_interval: optional_number(&record, "pollingInterval")?,
256 client_id: optional_string_field(&record, "clientId")?,
257 scope: optional_string_field(&record, "scope")?,
258 created_at: required_timestamp(&record, "createdAt")?,
259 updated_at: required_timestamp(&record, "updatedAt")?,
260 })
261}
262
263fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, OpenAuthError> {
264 match record.get(field) {
265 Some(DbValue::String(value)) => Ok(value),
266 Some(_) => Err(invalid_field(field, "string")),
267 None => Err(missing_field(field)),
268 }
269}
270
271fn optional_string_field(record: &DbRecord, field: &str) -> Result<Option<String>, OpenAuthError> {
272 match record.get(field) {
273 Some(DbValue::String(value)) => Ok(Some(value.to_owned())),
274 Some(DbValue::Null) | None => Ok(None),
275 Some(_) => Err(invalid_field(field, "string or null")),
276 }
277}
278
279fn required_timestamp(record: &DbRecord, field: &str) -> Result<OffsetDateTime, OpenAuthError> {
280 match record.get(field) {
281 Some(DbValue::Timestamp(value)) => Ok(*value),
282 Some(_) => Err(invalid_field(field, "timestamp")),
283 None => Err(missing_field(field)),
284 }
285}
286
287fn optional_timestamp(
288 record: &DbRecord,
289 field: &str,
290) -> Result<Option<OffsetDateTime>, OpenAuthError> {
291 match record.get(field) {
292 Some(DbValue::Timestamp(value)) => Ok(Some(*value)),
293 Some(DbValue::Null) | None => Ok(None),
294 Some(_) => Err(invalid_field(field, "timestamp or null")),
295 }
296}
297
298fn optional_number(record: &DbRecord, field: &str) -> Result<Option<i64>, OpenAuthError> {
299 match record.get(field) {
300 Some(DbValue::Number(value)) => Ok(Some(*value)),
301 Some(DbValue::Null) | None => Ok(None),
302 Some(_) => Err(invalid_field(field, "number or null")),
303 }
304}
305
306fn missing_field(field: &str) -> OpenAuthError {
307 OpenAuthError::Adapter(format!("device code record is missing `{field}`"))
308}
309
310fn invalid_field(field: &str, expected: &str) -> OpenAuthError {
311 OpenAuthError::Adapter(format!(
312 "device code record field `{field}` must be {expected}"
313 ))
314}