batuta/serve/banco/
auth.rs1use axum::{
9 body::Body,
10 http::{Request, Response, StatusCode},
11 middleware::Next,
12};
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15use std::sync::Arc;
16
/// How the server decides whether a request must carry credentials.
///
/// `Local` accepts every request without looking at a token; `ApiKey`
/// requires a bearer token registered in the [`AuthStore`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum AuthMode {
    /// No authentication: every request is accepted.
    Local,
    /// Requests must present a registered API key.
    ApiKey,
}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ApiKey {
29 pub key: String,
30 pub scope: KeyScope,
31 pub created: u64,
32}
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(rename_all = "lowercase")]
37pub enum KeyScope {
38 Chat,
40 Train,
42 Admin,
44}
45
46impl KeyScope {
47 #[must_use]
49 pub fn allows_path(&self, path: &str) -> bool {
50 match self {
51 Self::Admin => true,
52 Self::Train => {
53 !path.contains("/models/load")
54 && !path.contains("/models/unload")
55 && !path.contains("/config")
56 }
57 Self::Chat => {
58 let is_model_admin = path.contains("/models/load")
60 || path.contains("/models/unload")
61 || path.contains("/models/merge")
62 || path.contains("/models/pull")
63 || path.contains("/models/registry");
64
65 !is_model_admin
66 && (path.contains("/chat")
67 || path.contains("/models")
68 || path.contains("/health")
69 || path.contains("/system")
70 || path.contains("/tokenize")
71 || path.contains("/detokenize")
72 || path.contains("/embeddings")
73 || path.contains("/prompts")
74 || path.contains("/conversations")
75 || path.contains("/tags")
76 || path.contains("/show")
77 || path.contains("/completions"))
78 }
79 }
80 }
81}
82
83#[derive(Clone)]
85pub struct AuthStore {
86 mode: AuthMode,
87 keys: Arc<std::sync::RwLock<HashSet<String>>>,
88 key_details: Arc<std::sync::RwLock<Vec<ApiKey>>>,
89}
90
91impl AuthStore {
92 #[must_use]
94 pub fn local() -> Self {
95 Self {
96 mode: AuthMode::Local,
97 keys: Arc::new(std::sync::RwLock::new(HashSet::new())),
98 key_details: Arc::new(std::sync::RwLock::new(Vec::new())),
99 }
100 }
101
102 #[must_use]
104 pub fn api_key_mode() -> (Self, String) {
105 let key = generate_key();
106 let mut keys = HashSet::new();
107 keys.insert(key.clone());
108
109 let detail = ApiKey { key: key.clone(), scope: KeyScope::Admin, created: epoch_secs() };
110
111 let store = Self {
112 mode: AuthMode::ApiKey,
113 keys: Arc::new(std::sync::RwLock::new(keys)),
114 key_details: Arc::new(std::sync::RwLock::new(vec![detail])),
115 };
116 (store, key)
117 }
118
119 #[must_use]
121 pub fn requires_auth(&self) -> bool {
122 self.mode == AuthMode::ApiKey
123 }
124
125 #[must_use]
127 pub fn validate(&self, token: &str) -> bool {
128 if self.mode == AuthMode::Local {
129 return true;
130 }
131 self.keys.read().map(|k| k.contains(token)).unwrap_or(false)
132 }
133
134 #[must_use]
136 pub fn validate_for_path(&self, token: &str, path: &str) -> bool {
137 if self.mode == AuthMode::Local {
138 return true;
139 }
140 if let Ok(details) = self.key_details.read() {
141 if let Some(key) = details.iter().find(|k| k.key == token) {
142 return key.scope.allows_path(path);
143 }
144 }
145 false
146 }
147
148 pub fn generate_scoped_key(&self, scope: KeyScope) -> ApiKey {
150 let key_str = generate_key();
151 let api_key = ApiKey { key: key_str.clone(), scope, created: epoch_secs() };
152 if let Ok(mut keys) = self.keys.write() {
153 keys.insert(key_str);
154 }
155 if let Ok(mut details) = self.key_details.write() {
156 details.push(api_key.clone());
157 }
158 api_key
159 }
160
161 #[must_use]
163 pub fn list_keys(&self) -> Vec<ApiKeyInfo> {
164 self.key_details
165 .read()
166 .map(|details| {
167 details
168 .iter()
169 .map(|k| ApiKeyInfo {
170 prefix: k.key.chars().take(10).collect::<String>() + "...",
171 scope: k.scope,
172 created: k.created,
173 })
174 .collect()
175 })
176 .unwrap_or_default()
177 }
178
179 #[must_use]
181 pub fn mode(&self) -> AuthMode {
182 self.mode
183 }
184
185 #[must_use]
187 pub fn key_count(&self) -> usize {
188 self.keys.read().map(|k| k.len()).unwrap_or(0)
189 }
190}
191
192#[derive(Debug, Clone, Serialize)]
194pub struct ApiKeyInfo {
195 pub prefix: String,
196 pub scope: KeyScope,
197 pub created: u64,
198}
199
/// Generates a new API key: `bk_` followed by 16 lowercase hex digits.
///
/// The digits are a SipHash of the current time, process id and a counter,
/// keyed through [`RandomState`]. The standard library documents
/// `RandomState::new()` as initialized with random keys, drawn best-effort
/// from the host's secure randomness source. The output is therefore not a
/// function of public data alone.
///
/// Do not switch back to `DefaultHasher::new()`: it uses fixed keys, which
/// makes the result a deterministic function of (time, pid, counter). The
/// creation second is exposed by key listings, so such keys can be
/// recomputed offline.
///
/// NOTE(review): this is still not a dedicated CSPRNG and yields 64 bits.
/// If adding a dependency is acceptable, `getrandom`/`rand` with a longer
/// key would be stronger; the current format is kept for compatibility.
///
/// [`RandomState`]: std::collections::hash_map::RandomState
fn generate_key() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::SeqCst);

    // Secret, randomly keyed SipHash instance.
    let mut hasher = RandomState::new().build_hasher();

    // Mixing in time, pid and a counter keeps successive keys distinct even
    // if two hasher instances were ever to share keys.
    let nanos = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_nanos();
    nanos.hash(&mut hasher);
    std::process::id().hash(&mut hasher);
    seq.hash(&mut hasher);

    format!("bk_{:016x}", hasher.finish())
}
216
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
220
221pub async fn auth_layer(
223 auth_store: AuthStore,
224 request: Request<Body>,
225 next: Next,
226) -> Result<Response<Body>, StatusCode> {
227 if !auth_store.requires_auth() {
228 return Ok(next.run(request).await);
229 }
230
231 let path_str = request.uri().path();
233 if path_str.starts_with("/health") {
234 return Ok(next.run(request).await);
235 }
236
237 let token = request
239 .headers()
240 .get("authorization")
241 .and_then(|v| v.to_str().ok())
242 .and_then(|v| v.strip_prefix("Bearer "));
243
244 let path = request.uri().path().to_string();
245 match token {
246 Some(t) if auth_store.validate_for_path(t, &path) => Ok(next.run(request).await),
247 Some(_) => Err(StatusCode::FORBIDDEN), _ => Err(StatusCode::UNAUTHORIZED),
249 }
250}