1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use serde::Serialize;
6use url::Url;
7
8use crate::db::Db;
9use crate::error::AuthError;
10use crate::types::{
11 AccentInk, ApplicationId, ClientId, ClientSecret, ClientType, Mode, PasswordHash,
12 SplashPrimitive, UserId,
13};
14
15#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
20pub struct Application {
21 pub id: ApplicationId,
22 pub name: String,
23 pub client_id: ClientId,
24 pub client_type: ClientType,
25 #[serde(skip_serializing)]
26 pub client_secret_hash: Option<PasswordHash>,
27 pub redirect_uris: String, pub logo_url: Option<String>,
29 pub primary_color: Option<String>,
30 pub accent_hex: Option<String>,
32 pub accent_ink: Option<AccentInk>,
33 pub forced_mode: Option<Mode>,
34 pub font_css_url: Option<String>,
35 pub font_family: Option<String>,
36 pub splash_text: Option<String>,
37 pub splash_image_url: Option<String>,
38 pub splash_primitive: Option<SplashPrimitive>,
39 pub splash_url: Option<String>,
40 pub shader_cell_scale: Option<i64>,
41 pub is_trusted: bool,
42 pub created_by: Option<UserId>,
43 pub is_active: bool,
44 pub created_at: DateTime<Utc>,
45 pub updated_at: DateTime<Utc>,
46}
47
48#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
57pub struct BrandingConfig {
58 pub application_name: String,
59 pub logo_url: Option<String>,
60 pub primary_color: Option<String>,
61 pub accent_hex: Option<String>,
62 pub accent_ink: Option<AccentInk>,
63 pub forced_mode: Option<Mode>,
64 pub font_css_url: Option<String>,
65 pub font_family: Option<String>,
66 pub splash_text: Option<String>,
67 pub splash_image_url: Option<String>,
68 pub splash_primitive: Option<SplashPrimitive>,
69 pub splash_url: Option<String>,
70 pub shader_cell_scale: Option<i64>,
71 #[sqlx(skip)]
74 pub title_brand: Option<String>,
75}
76
77pub fn generate_client_id() -> ClientId {
82 let mut bytes = [0u8; 24];
83 OsRng
84 .try_fill_bytes(&mut bytes)
85 .expect("OS RNG unavailable");
86 let encoded = Base64UrlUnpadded::encode_string(&bytes);
87 ClientId::new_unchecked(format!("ath_{encoded}"))
88}
89
90pub fn generate_client_secret() -> Result<(ClientSecret, PasswordHash), AuthError> {
97 let mut bytes = [0u8; 32];
98 OsRng
99 .try_fill_bytes(&mut bytes)
100 .expect("OS RNG unavailable");
101 let raw = Base64UrlUnpadded::encode_string(&bytes);
102 let hash = crate::password::hash_password(&raw)?;
103 Ok((ClientSecret::new_unchecked(raw), hash))
104}
105
106impl Application {
107 pub fn redirect_uri_list(&self) -> Result<Vec<String>, AuthError> {
113 serde_json::from_str(&self.redirect_uris)
114 .map_err(|e| AuthError::Database(sqlx::Error::Decode(Box::new(e))))
115 }
116
117 pub fn branding(&self) -> BrandingConfig {
119 BrandingConfig {
120 application_name: self.name.clone(),
121 logo_url: self.logo_url.clone(),
122 primary_color: self.primary_color.clone(),
123 accent_hex: self.accent_hex.clone(),
124 accent_ink: self.accent_ink,
125 forced_mode: self.forced_mode,
126 font_css_url: self.font_css_url.clone(),
127 font_family: self.font_family.clone(),
128 splash_text: self.splash_text.clone(),
129 splash_image_url: self.splash_image_url.clone(),
130 splash_primitive: self.splash_primitive,
131 splash_url: self.splash_url.clone(),
132 shader_cell_scale: self.shader_cell_scale,
133 title_brand: None,
134 }
135 }
136}
137
138impl BrandingConfig {
139 pub fn new(application_name: impl Into<String>) -> Self {
143 Self {
144 application_name: application_name.into(),
145 logo_url: None,
146 primary_color: None,
147 accent_hex: None,
148 accent_ink: None,
149 forced_mode: None,
150 font_css_url: None,
151 font_family: None,
152 splash_text: None,
153 splash_image_url: None,
154 splash_primitive: None,
155 splash_url: None,
156 shader_cell_scale: None,
157 title_brand: None,
158 }
159 }
160
161 pub fn with_accent(mut self, hex: impl Into<String>, ink: AccentInk) -> Self {
162 self.accent_hex = Some(hex.into());
163 self.accent_ink = Some(ink);
164 self
165 }
166
167 pub fn with_primary_color(mut self, hex: impl Into<String>) -> Self {
168 self.primary_color = Some(hex.into());
169 self
170 }
171
172 pub fn with_logo_url(mut self, url: impl Into<String>) -> Self {
173 self.logo_url = Some(url.into());
174 self
175 }
176
177 pub fn with_splash_text(mut self, text: impl Into<String>) -> Self {
178 self.splash_text = Some(text.into());
179 self
180 }
181
182 pub fn with_splash_image_url(mut self, url: impl Into<String>) -> Self {
183 self.splash_image_url = Some(url.into());
184 self
185 }
186
187 pub fn with_splash_primitive(mut self, primitive: SplashPrimitive) -> Self {
188 self.splash_primitive = Some(primitive);
189 self
190 }
191
192 pub fn with_shader_cell_scale(mut self, scale: i64) -> Self {
193 self.shader_cell_scale = Some(scale);
194 self
195 }
196
197 pub fn with_title_brand(mut self, brand: impl Into<String>) -> Self {
198 self.title_brand = Some(brand.into());
199 self
200 }
201}
202
203fn map_unique_violation(err: sqlx::Error) -> AuthError {
204 if let sqlx::Error::Database(ref db_err) = err {
205 let msg = db_err.message();
206 if msg.contains("UNIQUE constraint failed") && msg.contains("client_id") {
207 return AuthError::Conflict("client_id already exists".into());
208 }
209 }
210 AuthError::Database(err)
211}
212
213pub struct ApplicationCursor {
218 pub created_at: DateTime<Utc>,
219 pub id: ApplicationId,
220}
221
222#[derive(serde::Serialize, serde::Deserialize)]
223struct RawCursor {
224 ca: String,
225 id: String,
226}
227
228impl ApplicationCursor {
229 pub fn from_app(app: &Application) -> Self {
230 Self {
231 created_at: app.created_at,
232 id: app.id,
233 }
234 }
235
236 pub fn encode(&self) -> String {
237 let raw = RawCursor {
238 ca: self.created_at.to_rfc3339(),
239 id: self.id.to_string(),
240 };
241 let json = serde_json::to_string(&raw).expect("RawCursor serializes");
242 Base64UrlUnpadded::encode_string(json.as_bytes())
243 }
244
245 pub fn decode(s: &str) -> Option<Self> {
246 let bytes = Base64UrlUnpadded::decode_vec(s).ok()?;
247 let raw: RawCursor = serde_json::from_slice(&bytes).ok()?;
248 let created_at = chrono::DateTime::parse_from_rfc3339(&raw.ca)
249 .ok()?
250 .with_timezone(&Utc);
251 let id = raw
252 .id
253 .parse::<uuid::Uuid>()
254 .ok()
255 .map(ApplicationId::from_uuid)?;
256 Some(Self { created_at, id })
257 }
258}
259
260pub struct CreateApplicationParams {
262 pub name: String,
263 pub client_type: ClientType,
264 pub redirect_uris: Vec<String>,
265 pub is_trusted: bool,
266 pub created_by: Option<UserId>,
267 pub logo_url: Option<String>,
268 pub primary_color: Option<String>,
269 pub accent_hex: Option<String>,
270 pub accent_ink: Option<AccentInk>,
271 pub forced_mode: Option<Mode>,
272 pub font_css_url: Option<String>,
273 pub font_family: Option<String>,
274 pub splash_text: Option<String>,
275 pub splash_image_url: Option<String>,
276 pub splash_primitive: Option<SplashPrimitive>,
277 pub splash_url: Option<String>,
278 pub shader_cell_scale: Option<i64>,
279}
280
281pub struct UpdateApplication {
286 pub name: String,
287 pub redirect_uris: Vec<String>,
288 pub is_trusted: bool,
289 pub is_active: bool,
290 pub logo_url: Option<String>,
291 pub primary_color: Option<String>,
292 pub accent_hex: Option<String>,
293 pub accent_ink: Option<AccentInk>,
294 pub forced_mode: Option<Mode>,
295 pub font_css_url: Option<String>,
296 pub font_family: Option<String>,
297 pub splash_text: Option<String>,
298 pub splash_image_url: Option<String>,
299 pub splash_primitive: Option<SplashPrimitive>,
300 pub splash_url: Option<String>,
301 pub shader_cell_scale: Option<i64>,
302}
303
304pub fn validate_redirect_uris(uris: &[String]) -> Result<(), AuthError> {
315 if uris.is_empty() {
316 return Err(AuthError::InvalidRedirectUri(
317 "redirect_uris must not be empty".into(),
318 ));
319 }
320 for uri in uris {
321 let parsed = Url::parse(uri).map_err(|_| AuthError::InvalidRedirectUri(uri.clone()))?;
322 if parsed.fragment().is_some() {
323 return Err(AuthError::InvalidRedirectUri(uri.clone()));
324 }
325 let scheme = parsed.scheme();
326 if scheme == "https" {
327 continue;
328 }
329 if scheme == "http" {
330 let host = parsed.host_str().unwrap_or("");
331 if host == "localhost" || host == "127.0.0.1" {
332 continue;
333 }
334 }
335 return Err(AuthError::InvalidRedirectUri(uri.clone()));
336 }
337 Ok(())
338}
339
340pub fn validate_redirect_uri(redirect_uri: &str, registered: &[String]) -> Result<(), AuthError> {
347 if registered.iter().any(|r| r == redirect_uri) {
348 Ok(())
349 } else {
350 Err(AuthError::InvalidRedirectUri(redirect_uri.to_owned()))
351 }
352}
353
354pub fn validate_logo_url(url: &str) -> Result<(), AuthError> {
360 let parsed = Url::parse(url)
361 .map_err(|_| AuthError::Validation("logo_url must be a valid absolute URL".into()))?;
362 let scheme = parsed.scheme();
363 if scheme == "https" {
364 return Ok(());
365 }
366 if scheme == "http" {
367 let host = parsed.host_str().unwrap_or("");
368 if host == "localhost" || host == "127.0.0.1" {
369 return Ok(());
370 }
371 }
372 Err(AuthError::Validation(
373 "logo_url must be an HTTPS URL".into(),
374 ))
375}
376
377pub fn validate_font_css_url(url: &str) -> Result<(), AuthError> {
380 validate_https_url(url, "font_css_url")
381}
382
383pub fn validate_splash_image_url(url: &str) -> Result<(), AuthError> {
385 validate_https_url(url, "splash_image_url")
386}
387
388pub fn validate_splash_url(url: &str) -> Result<(), AuthError> {
390 validate_https_url(url, "splash_url")
391}
392
393fn validate_https_url(url: &str, field: &str) -> Result<(), AuthError> {
398 let parsed = Url::parse(url)
399 .map_err(|_| AuthError::Validation(format!("{field} must be a valid absolute URL")))?;
400 if parsed.scheme() != "https" {
401 return Err(AuthError::Validation(format!(
402 "{field} must be an HTTPS URL"
403 )));
404 }
405 Ok(())
406}
407
408fn validate_hex_color(color: &str, field: &str) -> Result<(), AuthError> {
413 let bytes = color.as_bytes();
414 if bytes.len() != 7 || bytes[0] != b'#' || !bytes[1..].iter().all(|b| b.is_ascii_hexdigit()) {
415 return Err(AuthError::Validation(format!(
416 "{field} must be a hex color (#RRGGBB)"
417 )));
418 }
419 Ok(())
420}
421
422pub fn validate_primary_color(color: &str) -> Result<(), AuthError> {
428 validate_hex_color(color, "primary_color")
429}
430
431pub fn validate_accent_hex(color: &str) -> Result<(), AuthError> {
437 validate_hex_color(color, "accent_hex")
438}
439
440impl Db {
441 pub async fn create_application(
451 &self,
452 params: CreateApplicationParams,
453 ) -> Result<(Application, Option<ClientSecret>), AuthError> {
454 let CreateApplicationParams {
455 name,
456 client_type,
457 redirect_uris,
458 is_trusted,
459 created_by,
460 logo_url,
461 primary_color,
462 accent_hex,
463 accent_ink,
464 forced_mode,
465 font_css_url,
466 font_family,
467 splash_text,
468 splash_image_url,
469 splash_primitive,
470 splash_url,
471 shader_cell_scale,
472 } = params;
473 validate_redirect_uris(&redirect_uris)?;
474 if let Some(ref url) = logo_url {
475 validate_logo_url(url)?;
476 }
477 if let Some(ref color) = primary_color {
478 validate_primary_color(color)?;
479 }
480 if let Some(ref hex) = accent_hex {
481 validate_accent_hex(hex)?;
482 }
483 if let Some(ref url) = font_css_url {
484 validate_font_css_url(url)?;
485 }
486 if let Some(ref url) = splash_image_url {
487 validate_splash_image_url(url)?;
488 }
489 if let Some(ref url) = splash_url {
490 validate_splash_url(url)?;
491 }
492 let id = ApplicationId::new();
493 let client_id = generate_client_id();
494 let (raw_secret, hash) = match client_type {
495 ClientType::Confidential => {
496 let (secret, hash) = generate_client_secret()?;
497 (Some(secret), Some(hash))
498 }
499 ClientType::Public => (None, None),
500 };
501 let redirect_uris_json =
502 serde_json::to_string(&redirect_uris).expect("Vec<String> serializes to JSON");
503 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
504
505 sqlx::query(
506 "INSERT INTO allowthem_applications \
507 (id, name, client_id, client_type, client_secret_hash, redirect_uris, logo_url, \
508 primary_color, \
509 accent_hex, accent_ink, forced_mode, font_css_url, font_family, \
510 splash_text, splash_image_url, splash_primitive, splash_url, shader_cell_scale, \
511 is_trusted, created_by, is_active, created_at, updated_at) \
512 VALUES \
513 (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, \
514 ?9, ?10, ?11, ?12, ?13, \
515 ?14, ?15, ?16, ?17, ?18, \
516 ?19, ?20, 1, ?21, ?21)",
517 )
518 .bind(id)
519 .bind(&name)
520 .bind(&client_id)
521 .bind(client_type)
522 .bind(&hash)
523 .bind(&redirect_uris_json)
524 .bind(&logo_url)
525 .bind(&primary_color)
526 .bind(&accent_hex)
527 .bind(accent_ink.map(|v| v.as_str()))
528 .bind(forced_mode.map(|v| v.as_str()))
529 .bind(&font_css_url)
530 .bind(&font_family)
531 .bind(&splash_text)
532 .bind(&splash_image_url)
533 .bind(splash_primitive.map(|v| v.as_str()))
534 .bind(&splash_url)
535 .bind(shader_cell_scale)
536 .bind(is_trusted)
537 .bind(created_by)
538 .bind(&now)
539 .execute(self.pool())
540 .await
541 .map_err(map_unique_violation)?;
542
543 let app = self.get_application(id).await?;
544 Ok((app, raw_secret))
545 }
546
547 pub async fn get_application(&self, id: ApplicationId) -> Result<Application, AuthError> {
549 sqlx::query_as::<_, Application>(
550 "SELECT id, name, client_id, client_type, client_secret_hash, redirect_uris, \
551 logo_url, primary_color, \
552 accent_hex, accent_ink, forced_mode, font_css_url, font_family, \
553 splash_text, splash_image_url, splash_primitive, splash_url, shader_cell_scale, \
554 is_trusted, created_by, is_active, created_at, updated_at \
555 FROM allowthem_applications WHERE id = ?",
556 )
557 .bind(id)
558 .fetch_optional(self.pool())
559 .await?
560 .ok_or(AuthError::NotFound)
561 }
562
563 pub async fn get_application_by_client_id(
567 &self,
568 client_id: &ClientId,
569 ) -> Result<Application, AuthError> {
570 sqlx::query_as::<_, Application>(
571 "SELECT id, name, client_id, client_type, client_secret_hash, redirect_uris, \
572 logo_url, primary_color, \
573 accent_hex, accent_ink, forced_mode, font_css_url, font_family, \
574 splash_text, splash_image_url, splash_primitive, splash_url, shader_cell_scale, \
575 is_trusted, created_by, is_active, created_at, updated_at \
576 FROM allowthem_applications WHERE client_id = ?",
577 )
578 .bind(client_id)
579 .fetch_optional(self.pool())
580 .await?
581 .ok_or(AuthError::NotFound)
582 }
583
584 pub async fn get_branding_by_client_id(
590 &self,
591 client_id: &ClientId,
592 ) -> Result<Option<BrandingConfig>, AuthError> {
593 sqlx::query_as::<_, BrandingConfig>(
594 "SELECT name AS application_name, logo_url, primary_color, \
595 accent_hex, accent_ink, forced_mode, font_css_url, font_family, \
596 splash_text, splash_image_url, splash_primitive, splash_url, shader_cell_scale \
597 FROM allowthem_applications \
598 WHERE client_id = ? AND is_active = 1",
599 )
600 .bind(client_id)
601 .fetch_optional(self.pool())
602 .await
603 .map_err(AuthError::Database)
604 }
605
606 pub async fn count_applications(&self) -> Result<u64, AuthError> {
609 let n: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM allowthem_applications")
610 .fetch_one(self.pool())
611 .await
612 .map_err(AuthError::Database)?;
613 Ok(n as u64)
614 }
615
616 pub async fn count_users_for_application(&self, id: ApplicationId) -> Result<u64, AuthError> {
623 let count: i64 =
624 sqlx::query_scalar("SELECT COUNT(*) FROM allowthem_consents WHERE application_id = ?1")
625 .bind(id)
626 .fetch_one(self.pool())
627 .await
628 .map_err(AuthError::Database)?;
629 Ok(count as u64)
630 }
631
632 pub async fn list_applications(&self) -> Result<Vec<Application>, AuthError> {
634 sqlx::query_as::<_, Application>(
635 "SELECT id, name, client_id, client_type, client_secret_hash, redirect_uris, \
636 logo_url, primary_color, \
637 accent_hex, accent_ink, forced_mode, font_css_url, font_family, \
638 splash_text, splash_image_url, splash_primitive, splash_url, shader_cell_scale, \
639 is_trusted, created_by, is_active, created_at, updated_at \
640 FROM allowthem_applications ORDER BY created_at ASC",
641 )
642 .fetch_all(self.pool())
643 .await
644 .map_err(AuthError::Database)
645 }
646
647 pub async fn list_applications_paginated(
651 &self,
652 limit: u32,
653 cursor: Option<&ApplicationCursor>,
654 ) -> Result<Vec<Application>, AuthError> {
655 let limit = (limit as i64).min(200);
656 match cursor {
657 None => sqlx::query_as::<_, Application>(
658 "SELECT id, name, client_id, client_type, client_secret_hash, \
659 redirect_uris, logo_url, primary_color, \
660 accent_hex, accent_ink, forced_mode, font_css_url, font_family, \
661 splash_text, splash_image_url, splash_primitive, splash_url, shader_cell_scale, \
662 is_trusted, created_by, is_active, created_at, updated_at \
663 FROM allowthem_applications \
664 ORDER BY created_at ASC, id ASC LIMIT ?1",
665 )
666 .bind(limit)
667 .fetch_all(self.pool())
668 .await
669 .map_err(AuthError::Database),
670 Some(cur) => {
671 let ca = cur.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
674 sqlx::query_as::<_, Application>(
675 "SELECT id, name, client_id, client_type, client_secret_hash, \
676 redirect_uris, logo_url, primary_color, \
677 accent_hex, accent_ink, forced_mode, font_css_url, font_family, \
678 splash_text, splash_image_url, splash_primitive, splash_url, shader_cell_scale, \
679 is_trusted, created_by, is_active, created_at, updated_at \
680 FROM allowthem_applications \
681 WHERE (created_at > ?1 OR (created_at = ?1 AND id > ?2)) \
682 ORDER BY created_at ASC, id ASC LIMIT ?3",
683 )
684 .bind(&ca)
685 .bind(cur.id)
686 .bind(limit)
687 .fetch_all(self.pool())
688 .await
689 .map_err(AuthError::Database)
690 }
691 }
692 }
693
694 pub async fn update_application(
703 &self,
704 id: ApplicationId,
705 params: UpdateApplication,
706 ) -> Result<(), AuthError> {
707 validate_redirect_uris(¶ms.redirect_uris)?;
708 if let Some(ref url) = params.logo_url {
709 validate_logo_url(url)?;
710 }
711 if let Some(ref color) = params.primary_color {
712 validate_primary_color(color)?;
713 }
714 if let Some(ref hex) = params.accent_hex {
715 validate_accent_hex(hex)?;
716 }
717 if let Some(ref url) = params.font_css_url {
718 validate_font_css_url(url)?;
719 }
720 if let Some(ref url) = params.splash_image_url {
721 validate_splash_image_url(url)?;
722 }
723 if let Some(ref url) = params.splash_url {
724 validate_splash_url(url)?;
725 }
726 let redirect_uris_json =
727 serde_json::to_string(¶ms.redirect_uris).expect("Vec<String> serializes to JSON");
728 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
729
730 let result = sqlx::query(
731 "UPDATE allowthem_applications \
732 SET name = ?1, redirect_uris = ?2, is_trusted = ?3, is_active = ?4, \
733 logo_url = ?5, primary_color = ?6, \
734 accent_hex = ?7, accent_ink = ?8, forced_mode = ?9, \
735 font_css_url = ?10, font_family = ?11, \
736 splash_text = ?12, splash_image_url = ?13, splash_primitive = ?14, \
737 splash_url = ?15, shader_cell_scale = ?16, \
738 updated_at = ?17 \
739 WHERE id = ?18",
740 )
741 .bind(¶ms.name)
742 .bind(&redirect_uris_json)
743 .bind(params.is_trusted)
744 .bind(params.is_active)
745 .bind(¶ms.logo_url)
746 .bind(¶ms.primary_color)
747 .bind(¶ms.accent_hex)
748 .bind(params.accent_ink.map(|v| v.as_str()))
749 .bind(params.forced_mode.map(|v| v.as_str()))
750 .bind(¶ms.font_css_url)
751 .bind(¶ms.font_family)
752 .bind(¶ms.splash_text)
753 .bind(¶ms.splash_image_url)
754 .bind(params.splash_primitive.map(|v| v.as_str()))
755 .bind(¶ms.splash_url)
756 .bind(params.shader_cell_scale)
757 .bind(&now)
758 .bind(id)
759 .execute(self.pool())
760 .await?;
761
762 if result.rows_affected() == 0 {
763 return Err(AuthError::NotFound);
764 }
765 Ok(())
766 }
767
768 pub async fn regenerate_client_secret(
776 &self,
777 id: ApplicationId,
778 ) -> Result<(Application, ClientSecret), AuthError> {
779 let application = self.get_application(id).await?;
780 if application.client_type == ClientType::Public {
781 return Err(AuthError::InvalidRequest(
782 "public clients have no client secret".into(),
783 ));
784 }
785 let (raw_secret, hash) = generate_client_secret()?;
786 let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
787
788 let result = sqlx::query(
789 "UPDATE allowthem_applications \
790 SET client_secret_hash = ?1, updated_at = ?2 \
791 WHERE id = ?3",
792 )
793 .bind(&hash)
794 .bind(&now)
795 .bind(id)
796 .execute(self.pool())
797 .await?;
798
799 if result.rows_affected() == 0 {
800 return Err(AuthError::NotFound);
801 }
802
803 let app = self.get_application(id).await?;
804 Ok((app, raw_secret))
805 }
806
807 pub async fn delete_application(&self, id: ApplicationId) -> Result<(), AuthError> {
812 let result = sqlx::query("DELETE FROM allowthem_applications WHERE id = ?")
813 .bind(id)
814 .execute(self.pool())
815 .await?;
816
817 if result.rows_affected() == 0 {
818 return Err(AuthError::NotFound);
819 }
820 Ok(())
821 }
822}
823
824#[cfg(test)]
825mod tests {
826 use super::*;
827 use crate::password::verify_password;
828 use crate::types::ApplicationId;
829
830 #[test]
831 fn client_id_has_ath_prefix() {
832 let id = generate_client_id();
833 assert!(
834 id.as_str().starts_with("ath_"),
835 "client_id must start with ath_"
836 );
837 }
838
839 #[test]
840 fn client_id_length_is_36() {
841 let id = generate_client_id();
842 assert_eq!(id.as_str().len(), 36, "ath_(4) + 32 base64url chars = 36");
843 }
844
845 #[test]
846 fn client_id_chars_are_url_safe() {
847 let id = generate_client_id();
848 let suffix = &id.as_str()[4..];
850 assert!(
851 suffix
852 .chars()
853 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
854 "client_id suffix must be URL-safe base64url: got {suffix}"
855 );
856 }
857
858 #[test]
859 fn two_client_ids_differ() {
860 let a = generate_client_id();
861 let b = generate_client_id();
862 assert_ne!(a, b, "each client_id must be unique");
863 }
864
865 #[test]
866 fn client_secret_verifies_round_trip() {
867 let (secret, hash) = generate_client_secret().expect("generate_client_secret");
868 let valid = verify_password(secret.as_str(), &hash).expect("verify_password");
869 assert!(valid, "generated secret must verify against its own hash");
870 }
871
872 #[test]
873 fn two_client_secrets_differ() {
874 let (s1, _) = generate_client_secret().expect("secret 1");
875 let (s2, _) = generate_client_secret().expect("secret 2");
876 assert_ne!(s1.as_str(), s2.as_str(), "each secret must be unique");
877 }
878
879 #[test]
880 fn wrong_secret_does_not_verify() {
881 let (_, hash) = generate_client_secret().expect("generate_client_secret");
882 let valid = verify_password("wrong-secret", &hash).expect("verify_password");
883 assert!(!valid, "wrong secret must not verify");
884 }
885
886 #[test]
889 fn redirect_uri_empty_list_is_rejected() {
890 let err = validate_redirect_uris(&[]).unwrap_err();
891 assert!(matches!(err, AuthError::InvalidRedirectUri(_)));
892 }
893
894 #[test]
895 fn redirect_uri_https_is_valid() {
896 let uris = vec!["https://example.com/callback".to_string()];
897 assert!(validate_redirect_uris(&uris).is_ok());
898 }
899
900 #[test]
901 fn redirect_uri_http_localhost_is_valid() {
902 let uris = vec!["http://localhost/callback".to_string()];
903 assert!(validate_redirect_uris(&uris).is_ok());
904 }
905
906 #[test]
907 fn redirect_uri_http_localhost_with_port_is_valid() {
908 let uris = vec!["http://localhost:3000/callback".to_string()];
909 assert!(validate_redirect_uris(&uris).is_ok());
910 }
911
912 #[test]
913 fn redirect_uri_http_127_0_0_1_is_valid() {
914 let uris = vec!["http://127.0.0.1:8080/callback".to_string()];
915 assert!(validate_redirect_uris(&uris).is_ok());
916 }
917
918 #[test]
919 fn redirect_uri_http_non_localhost_is_rejected() {
920 let uris = vec!["http://example.com/callback".to_string()];
921 let err = validate_redirect_uris(&uris).unwrap_err();
922 assert!(matches!(err, AuthError::InvalidRedirectUri(_)));
923 }
924
925 #[test]
926 fn redirect_uri_with_fragment_is_rejected() {
927 let uris = vec!["https://example.com/callback#section".to_string()];
928 let err = validate_redirect_uris(&uris).unwrap_err();
929 assert!(matches!(err, AuthError::InvalidRedirectUri(_)));
930 }
931
932 #[test]
933 fn redirect_uri_relative_is_rejected() {
934 let uris = vec!["/callback".to_string()];
935 let err = validate_redirect_uris(&uris).unwrap_err();
936 assert!(matches!(err, AuthError::InvalidRedirectUri(_)));
937 }
938
939 #[test]
942 fn redirect_uri_exact_match_passes() {
943 let registered = vec!["https://example.com/callback".to_string()];
944 assert!(validate_redirect_uri("https://example.com/callback", ®istered).is_ok());
945 }
946
947 #[test]
948 fn redirect_uri_not_in_registered_is_rejected() {
949 let registered = vec!["https://example.com/callback".to_string()];
950 let err = validate_redirect_uri("https://example.com/other", ®istered).unwrap_err();
951 assert!(matches!(err, AuthError::InvalidRedirectUri(_)));
952 }
953
954 #[test]
957 fn redirect_uri_list_parses_valid_json() {
958 let (_, hash) = generate_client_secret().expect("generate_client_secret");
959 let app = Application {
960 id: ApplicationId::new(),
961 name: "Test".to_string(),
962 client_id: generate_client_id(),
963 client_type: ClientType::Confidential,
964 client_secret_hash: Some(hash),
965 redirect_uris: r#"["https://example.com/callback","https://example.com/other"]"#
966 .to_string(),
967 logo_url: None,
968 primary_color: None,
969 accent_hex: None,
970 accent_ink: None,
971 forced_mode: None,
972 font_css_url: None,
973 font_family: None,
974 splash_text: None,
975 splash_image_url: None,
976 splash_primitive: None,
977 splash_url: None,
978 shader_cell_scale: None,
979 is_trusted: false,
980 created_by: None,
981 is_active: true,
982 created_at: chrono::Utc::now(),
983 updated_at: chrono::Utc::now(),
984 };
985 let list = app.redirect_uri_list().expect("redirect_uri_list");
986 assert_eq!(
987 list,
988 vec![
989 "https://example.com/callback".to_string(),
990 "https://example.com/other".to_string(),
991 ]
992 );
993 }
994
995 #[test]
996 fn redirect_uri_list_returns_error_on_malformed_json() {
997 let (_, hash) = generate_client_secret().expect("generate_client_secret");
998 let app = Application {
999 id: ApplicationId::new(),
1000 name: "Test".to_string(),
1001 client_id: generate_client_id(),
1002 client_type: ClientType::Confidential,
1003 client_secret_hash: Some(hash),
1004 redirect_uris: "not valid json".to_string(),
1005 logo_url: None,
1006 primary_color: None,
1007 accent_hex: None,
1008 accent_ink: None,
1009 forced_mode: None,
1010 font_css_url: None,
1011 font_family: None,
1012 splash_text: None,
1013 splash_image_url: None,
1014 splash_primitive: None,
1015 splash_url: None,
1016 shader_cell_scale: None,
1017 is_trusted: false,
1018 created_by: None,
1019 is_active: true,
1020 created_at: chrono::Utc::now(),
1021 updated_at: chrono::Utc::now(),
1022 };
1023 assert!(matches!(
1024 app.redirect_uri_list(),
1025 Err(AuthError::Database(_))
1026 ));
1027 }
1028
1029 #[test]
1032 fn logo_url_https_is_valid() {
1033 assert!(validate_logo_url("https://example.com/logo.png").is_ok());
1034 }
1035
1036 #[test]
1037 fn logo_url_http_localhost_is_valid() {
1038 assert!(validate_logo_url("http://localhost:3000/logo.png").is_ok());
1039 }
1040
1041 #[test]
1042 fn logo_url_http_127_is_valid() {
1043 assert!(validate_logo_url("http://127.0.0.1:8080/logo.png").is_ok());
1044 }
1045
1046 #[test]
1047 fn logo_url_http_non_localhost_is_rejected() {
1048 let err = validate_logo_url("http://example.com/logo.png").unwrap_err();
1049 assert!(matches!(err, AuthError::Validation(_)));
1050 }
1051
1052 #[test]
1053 fn logo_url_relative_is_rejected() {
1054 let err = validate_logo_url("/logo.png").unwrap_err();
1055 assert!(matches!(err, AuthError::Validation(_)));
1056 }
1057
1058 #[test]
1059 fn logo_url_not_a_url_is_rejected() {
1060 let err = validate_logo_url("not a url").unwrap_err();
1061 assert!(matches!(err, AuthError::Validation(_)));
1062 }
1063
1064 #[test]
1067 fn primary_color_valid_hex() {
1068 assert!(validate_primary_color("#3B82F6").is_ok());
1069 }
1070
1071 #[test]
1072 fn primary_color_lowercase_hex() {
1073 assert!(validate_primary_color("#3b82f6").is_ok());
1074 }
1075
1076 #[test]
1077 fn primary_color_missing_hash() {
1078 let err = validate_primary_color("3B82F6").unwrap_err();
1079 assert!(matches!(err, AuthError::Validation(_)));
1080 }
1081
1082 #[test]
1083 fn primary_color_too_short() {
1084 let err = validate_primary_color("#FFF").unwrap_err();
1085 assert!(matches!(err, AuthError::Validation(_)));
1086 }
1087
1088 #[test]
1089 fn primary_color_too_long() {
1090 let err = validate_primary_color("#3B82F6FF").unwrap_err();
1091 assert!(matches!(err, AuthError::Validation(_)));
1092 }
1093
1094 #[test]
1095 fn primary_color_non_hex_chars() {
1096 let err = validate_primary_color("#ZZZZZZ").unwrap_err();
1097 assert!(matches!(err, AuthError::Validation(_)));
1098 }
1099
1100 #[test]
1101 fn primary_color_named_color_rejected() {
1102 let err = validate_primary_color("red").unwrap_err();
1103 assert!(matches!(err, AuthError::Validation(_)));
1104 }
1105
1106 #[test]
1109 fn branding_extracts_correct_fields() {
1110 let (_, hash) = generate_client_secret().expect("generate");
1111 let app = Application {
1112 id: ApplicationId::new(),
1113 name: "My App".to_string(),
1114 client_id: generate_client_id(),
1115 client_type: ClientType::Confidential,
1116 client_secret_hash: Some(hash),
1117 redirect_uris: r#"["https://example.com/cb"]"#.to_string(),
1118 logo_url: Some("https://example.com/logo.png".to_string()),
1119 primary_color: Some("#3B82F6".to_string()),
1120 accent_hex: None,
1121 accent_ink: None,
1122 forced_mode: None,
1123 font_css_url: None,
1124 font_family: None,
1125 splash_text: None,
1126 splash_image_url: None,
1127 splash_primitive: None,
1128 splash_url: None,
1129 shader_cell_scale: None,
1130 is_trusted: false,
1131 created_by: None,
1132 is_active: true,
1133 created_at: chrono::Utc::now(),
1134 updated_at: chrono::Utc::now(),
1135 };
1136 let b = app.branding();
1137 assert_eq!(b.application_name, "My App");
1138 assert_eq!(b.logo_url.as_deref(), Some("https://example.com/logo.png"));
1139 assert_eq!(b.primary_color.as_deref(), Some("#3B82F6"));
1140 }
1141
1142 #[test]
1145 fn https_url_accepts_https() {
1146 assert!(validate_font_css_url("https://example.com/x.css").is_ok());
1147 }
1148
1149 #[test]
1150 fn https_url_rejects_http() {
1151 let err = validate_font_css_url("http://example.com/x.css").unwrap_err();
1152 assert!(matches!(err, AuthError::Validation(_)));
1153 }
1154
1155 #[test]
1156 fn https_url_rejects_invalid() {
1157 let err = validate_font_css_url("not a url").unwrap_err();
1158 assert!(matches!(err, AuthError::Validation(_)));
1159 }
1160
1161 #[test]
1162 fn logo_url_loopback_hostname_accepted() {
1163 assert!(validate_logo_url("http://localhost/logo.png").is_ok());
1164 }
1165
1166 #[test]
1167 fn logo_url_loopback_ip_accepted() {
1168 assert!(validate_logo_url("http://127.0.0.1/logo.png").is_ok());
1169 }
1170
1171 #[test]
1172 fn font_css_url_rejects_localhost() {
1173 let err = validate_font_css_url("http://localhost/font.css").unwrap_err();
1174 assert!(matches!(err, AuthError::Validation(_)));
1175 }
1176
1177 #[test]
1180 fn accent_hex_valid() {
1181 assert!(validate_accent_hex("#ff6b35").is_ok());
1182 }
1183
1184 #[test]
1185 fn accent_hex_rejects_named_color() {
1186 let err = validate_accent_hex("red").unwrap_err();
1187 assert!(matches!(err, AuthError::Validation(_)));
1188 }
1189
1190 #[test]
1191 fn accent_hex_rejects_shorthand() {
1192 let err = validate_accent_hex("#fff").unwrap_err();
1193 assert!(matches!(err, AuthError::Validation(_)));
1194 }
1195
#[test]
// Expects `validate_accent_hex` to fail with `AuthError::Validation` for `#gggggg`,
// whose digits are not hexadecimal.
fn accent_hex_rejects_non_hex_chars() {
    let outcome = validate_accent_hex("#gggggg");
    assert!(matches!(outcome, Err(AuthError::Validation(_))));
}
1201
#[test]
// Expects `validate_primary_color` to return `Ok` for `#3B82F6`.
fn primary_color_still_valid_after_refactor() {
    let outcome = validate_primary_color("#3B82F6");
    assert!(outcome.is_ok());
}
1206
#[test]
// Serializes a confidential `Application` that carries a secret hash and inspects the JSON keys:
// `client_secret_hash` must be absent while `client_id` is still emitted.
fn application_serialization_omits_secret() {
    let (_, secret_hash) = generate_client_secret().expect("generate_client_secret");
    let app = Application {
        id: ApplicationId::new(),
        name: String::from("Test App"),
        client_id: generate_client_id(),
        client_type: ClientType::Confidential,
        client_secret_hash: Some(secret_hash),
        redirect_uris: String::from(r#"["https://example.com/callback"]"#),
        // No branding is configured; only the identity fields matter for this test.
        logo_url: None,
        primary_color: None,
        accent_hex: None,
        accent_ink: None,
        forced_mode: None,
        font_css_url: None,
        font_family: None,
        splash_text: None,
        splash_image_url: None,
        splash_primitive: None,
        splash_url: None,
        shader_cell_scale: None,
        is_trusted: false,
        created_by: None,
        is_active: true,
        created_at: chrono::Utc::now(),
        updated_at: chrono::Utc::now(),
    };

    let json = serde_json::to_value(&app).expect("serialize Application");
    let has_key = |key: &str| json.get(key).is_some();

    assert!(
        !has_key("client_secret_hash"),
        "client_secret_hash must not appear in serialized output"
    );
    assert!(
        has_key("client_id"),
        "client_id must appear in serialized output"
    );
}
1245
1246 #[tokio::test]
1249 async fn count_users_for_application_returns_consent_count() {
1250 let db = crate::db::Db::connect("sqlite::memory:")
1251 .await
1252 .expect("in-memory db");
1253 let (app, _secret) = db
1254 .create_application(CreateApplicationParams {
1255 name: "Count Test".into(),
1256 client_type: ClientType::Confidential,
1257 redirect_uris: vec!["https://example.com/callback".into()],
1258 is_trusted: false,
1259 created_by: None,
1260 logo_url: None,
1261 primary_color: None,
1262 accent_hex: None,
1263 accent_ink: None,
1264 forced_mode: None,
1265 font_css_url: None,
1266 font_family: None,
1267 splash_text: None,
1268 splash_image_url: None,
1269 splash_primitive: None,
1270 splash_url: None,
1271 shader_cell_scale: None,
1272 })
1273 .await
1274 .expect("create_application");
1275
1276 let email1 = crate::Email::new("u1@test.com".into()).expect("email");
1277 let email2 = crate::Email::new("u2@test.com".into()).expect("email");
1278 let user1 = db
1279 .create_user(email1, "pw", None, None)
1280 .await
1281 .expect("user1");
1282 let user2 = db
1283 .create_user(email2, "pw", None, None)
1284 .await
1285 .expect("user2");
1286
1287 let id1 = uuid::Uuid::new_v4();
1288 let id2 = uuid::Uuid::new_v4();
1289 sqlx::query(
1290 "INSERT OR IGNORE INTO allowthem_consents (id, user_id, application_id) \
1291 VALUES (?, ?, ?)",
1292 )
1293 .bind(id1.to_string())
1294 .bind(user1.id)
1295 .bind(app.id)
1296 .execute(db.pool())
1297 .await
1298 .expect("insert consent 1");
1299 sqlx::query(
1300 "INSERT OR IGNORE INTO allowthem_consents (id, user_id, application_id) \
1301 VALUES (?, ?, ?)",
1302 )
1303 .bind(id2.to_string())
1304 .bind(user2.id)
1305 .bind(app.id)
1306 .execute(db.pool())
1307 .await
1308 .expect("insert consent 2");
1309
1310 let count = db.count_users_for_application(app.id).await.expect("count");
1311 assert_eq!(count, 2, "expected 2 consented users");
1312 }
1313
1314 #[tokio::test]
1315 async fn count_users_is_zero_for_unknown_application() {
1316 let db = crate::db::Db::connect("sqlite::memory:")
1317 .await
1318 .expect("in-memory db");
1319 let unknown_id = ApplicationId::new();
1320 let count = db
1321 .count_users_for_application(unknown_id)
1322 .await
1323 .expect("count for unknown app");
1324 assert_eq!(count, 0, "no consents for unknown application");
1325 }
1326
1327 async fn make_app(db: &crate::db::Db) -> Application {
1330 let (app, _) = db
1331 .create_application(CreateApplicationParams {
1332 name: "Test App".into(),
1333 client_type: crate::types::ClientType::Confidential,
1334 redirect_uris: vec!["https://example.com/callback".into()],
1335 is_trusted: false,
1336 created_by: None,
1337 logo_url: None,
1338 primary_color: None,
1339 accent_hex: None,
1340 accent_ink: None,
1341 forced_mode: None,
1342 font_css_url: None,
1343 font_family: None,
1344 splash_text: None,
1345 splash_image_url: None,
1346 splash_primitive: None,
1347 splash_url: None,
1348 shader_cell_scale: None,
1349 })
1350 .await
1351 .expect("create_application");
1352 app
1353 }
1354
1355 #[tokio::test]
1356 async fn count_applications_zero_on_empty_db() {
1357 let db = crate::db::Db::connect("sqlite::memory:")
1358 .await
1359 .expect("in-memory db");
1360 let n = db.count_applications().await.expect("count_applications");
1361 assert_eq!(n, 0);
1362 }
1363
1364 #[tokio::test]
1365 async fn count_applications_after_create() {
1366 let db = crate::db::Db::connect("sqlite::memory:")
1367 .await
1368 .expect("in-memory db");
1369 make_app(&db).await;
1370 make_app(&db).await;
1371 make_app(&db).await;
1372 let n = db.count_applications().await.expect("count_applications");
1373 assert_eq!(n, 3);
1374 }
1375
#[cfg(test)]
mod branding_config_builder_tests {
    use super::*;
    use crate::types::{AccentInk, SplashPrimitive};

    #[test]
    // `BrandingConfig::new` stores the given name; every optional field inspected here is `None`.
    fn new_sets_application_name_leaves_rest_none() {
        let BrandingConfig {
            application_name,
            logo_url,
            primary_color,
            accent_hex,
            accent_ink,
            forced_mode,
            font_css_url,
            font_family,
            splash_text,
            splash_image_url,
            splash_primitive,
            splash_url,
            shader_cell_scale,
            ..
        } = BrandingConfig::new("Fixture Co");

        assert_eq!(application_name, "Fixture Co");
        assert!(logo_url.is_none());
        assert!(primary_color.is_none());
        assert!(accent_hex.is_none());
        assert!(accent_ink.is_none());
        assert!(forced_mode.is_none());
        assert!(font_css_url.is_none());
        assert!(font_family.is_none());
        assert!(splash_text.is_none());
        assert!(splash_image_url.is_none());
        assert!(splash_primitive.is_none());
        assert!(splash_url.is_none());
        assert!(shader_cell_scale.is_none());
    }

    #[test]
    // `with_accent` fills both `accent_hex` and `accent_ink`.
    fn with_accent_sets_hex_and_ink() {
        let hex = "#ff7a1a";
        let config = BrandingConfig::new("Co");
        let config = config.with_accent(hex, AccentInk::Black);
        assert_eq!(config.accent_hex.as_deref(), Some(hex));
        assert_eq!(config.accent_ink, Some(AccentInk::Black));
    }

    #[test]
    // `with_splash_text` fills `splash_text`.
    fn with_splash_text_sets_field() {
        let text = "TRANSFER";
        let config = BrandingConfig::new("Co");
        let config = config.with_splash_text(text);
        assert_eq!(config.splash_text.as_deref(), Some(text));
    }

    #[test]
    // `with_shader_cell_scale` fills `shader_cell_scale`.
    fn with_shader_cell_scale_sets_field() {
        let config = BrandingConfig::new("Co");
        let config = config.with_shader_cell_scale(18);
        assert_eq!(config.shader_cell_scale, Some(18));
    }

    #[test]
    // `with_splash_primitive` fills `splash_primitive`.
    fn with_splash_primitive_sets_field() {
        let config = BrandingConfig::new("Co");
        let config = config.with_splash_primitive(SplashPrimitive::Wave);
        assert_eq!(config.splash_primitive, Some(SplashPrimitive::Wave));
    }

    #[test]
    // `with_logo_url` fills `logo_url`.
    fn with_logo_url_sets_field() {
        let url = "https://cdn.example/logo.svg";
        let config = BrandingConfig::new("Co");
        let config = config.with_logo_url(url);
        assert_eq!(config.logo_url.as_deref(), Some(url));
    }

    #[test]
    // `with_primary_color` fills `primary_color`.
    fn with_primary_color_sets_field() {
        let color = "#0066ff";
        let config = BrandingConfig::new("Co");
        let config = config.with_primary_color(color);
        assert_eq!(config.primary_color.as_deref(), Some(color));
    }

    #[test]
    // `with_splash_image_url` fills `splash_image_url`.
    fn with_splash_image_url_sets_field() {
        let url = "https://cdn.example/splash.png";
        let config = BrandingConfig::new("Co");
        let config = config.with_splash_image_url(url);
        assert_eq!(config.splash_image_url.as_deref(), Some(url));
    }
}
1446}