Skip to main content

batuta/serve/banco/
auth.rs

1//! API key authentication for Banco.
2//!
3//! - Local mode (127.0.0.1): no auth required
4//! - LAN mode (0.0.0.0): API key required via `Authorization: Bearer bk_...`
5//!
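//! Keys are generated at startup and kept in memory by `AuthStore`. NOTE(review): this module
//! does not itself read or write `~/.banco/keys.toml`; confirm where (or whether) keys are persisted.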
7
8use axum::{
9    body::Body,
10    http::{Request, Response, StatusCode},
11    middleware::Next,
12};
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15use std::sync::Arc;
16
/// How requests reaching the server are authenticated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// Credentials are not checked; intended for a localhost-only bind.
    Local,
    /// A registered API key is required; intended for LAN/remote access.
    ApiKey,
}
25
26/// API key with scope.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ApiKey {
29    pub key: String,
30    pub scope: KeyScope,
31    pub created: u64,
32}
33
34/// Key scope controls which endpoints are accessible.
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(rename_all = "lowercase")]
37pub enum KeyScope {
38    /// Chat completions, models list, health.
39    Chat,
40    /// All of Chat + training, data, experiments.
41    Train,
42    /// All of Train + model load/unload, system config.
43    Admin,
44}
45
46impl KeyScope {
47    /// Check if this scope allows access to the given path.
48    #[must_use]
49    pub fn allows_path(&self, path: &str) -> bool {
50        match self {
51            Self::Admin => true,
52            Self::Train => {
53                !path.contains("/models/load")
54                    && !path.contains("/models/unload")
55                    && !path.contains("/config")
56            }
57            Self::Chat => {
58                // Chat allows read-only model info but NOT load/unload/merge/pull
59                let is_model_admin = path.contains("/models/load")
60                    || path.contains("/models/unload")
61                    || path.contains("/models/merge")
62                    || path.contains("/models/pull")
63                    || path.contains("/models/registry");
64
65                !is_model_admin
66                    && (path.contains("/chat")
67                        || path.contains("/models")
68                        || path.contains("/health")
69                        || path.contains("/system")
70                        || path.contains("/tokenize")
71                        || path.contains("/detokenize")
72                        || path.contains("/embeddings")
73                        || path.contains("/prompts")
74                        || path.contains("/conversations")
75                        || path.contains("/tags")
76                        || path.contains("/show")
77                        || path.contains("/completions"))
78            }
79        }
80    }
81}
82
83/// API key store.
84#[derive(Clone)]
85pub struct AuthStore {
86    mode: AuthMode,
87    keys: Arc<std::sync::RwLock<HashSet<String>>>,
88    key_details: Arc<std::sync::RwLock<Vec<ApiKey>>>,
89}
90
91impl AuthStore {
92    /// Create a store in local mode (no auth).
93    #[must_use]
94    pub fn local() -> Self {
95        Self {
96            mode: AuthMode::Local,
97            keys: Arc::new(std::sync::RwLock::new(HashSet::new())),
98            key_details: Arc::new(std::sync::RwLock::new(Vec::new())),
99        }
100    }
101
102    /// Create a store in API key mode with a generated key.
103    #[must_use]
104    pub fn api_key_mode() -> (Self, String) {
105        let key = generate_key();
106        let mut keys = HashSet::new();
107        keys.insert(key.clone());
108
109        let detail = ApiKey { key: key.clone(), scope: KeyScope::Admin, created: epoch_secs() };
110
111        let store = Self {
112            mode: AuthMode::ApiKey,
113            keys: Arc::new(std::sync::RwLock::new(keys)),
114            key_details: Arc::new(std::sync::RwLock::new(vec![detail])),
115        };
116        (store, key)
117    }
118
119    /// Check if auth is required.
120    #[must_use]
121    pub fn requires_auth(&self) -> bool {
122        self.mode == AuthMode::ApiKey
123    }
124
125    /// Validate a bearer token.
126    #[must_use]
127    pub fn validate(&self, token: &str) -> bool {
128        if self.mode == AuthMode::Local {
129            return true;
130        }
131        self.keys.read().map(|k| k.contains(token)).unwrap_or(false)
132    }
133
134    /// Validate a token against a specific path (scope-aware).
135    #[must_use]
136    pub fn validate_for_path(&self, token: &str, path: &str) -> bool {
137        if self.mode == AuthMode::Local {
138            return true;
139        }
140        if let Ok(details) = self.key_details.read() {
141            if let Some(key) = details.iter().find(|k| k.key == token) {
142                return key.scope.allows_path(path);
143            }
144        }
145        false
146    }
147
148    /// Generate a new scoped API key.
149    pub fn generate_scoped_key(&self, scope: KeyScope) -> ApiKey {
150        let key_str = generate_key();
151        let api_key = ApiKey { key: key_str.clone(), scope, created: epoch_secs() };
152        if let Ok(mut keys) = self.keys.write() {
153            keys.insert(key_str);
154        }
155        if let Ok(mut details) = self.key_details.write() {
156            details.push(api_key.clone());
157        }
158        api_key
159    }
160
161    /// List all keys (redacted).
162    #[must_use]
163    pub fn list_keys(&self) -> Vec<ApiKeyInfo> {
164        self.key_details
165            .read()
166            .map(|details| {
167                details
168                    .iter()
169                    .map(|k| ApiKeyInfo {
170                        prefix: k.key.chars().take(10).collect::<String>() + "...",
171                        scope: k.scope,
172                        created: k.created,
173                    })
174                    .collect()
175            })
176            .unwrap_or_default()
177    }
178
179    /// Get the auth mode.
180    #[must_use]
181    pub fn mode(&self) -> AuthMode {
182        self.mode
183    }
184
185    /// Number of registered keys.
186    #[must_use]
187    pub fn key_count(&self) -> usize {
188        self.keys.read().map(|k| k.len()).unwrap_or(0)
189    }
190}
191
192/// Redacted key info for listing.
193#[derive(Debug, Clone, Serialize)]
194pub struct ApiKeyInfo {
195    pub prefix: String,
196    pub scope: KeyScope,
197    pub created: u64,
198}
199
/// Generate a random API key with `bk_` prefix.
///
/// Format: `bk_` followed by 16 lowercase hex digits.
///
/// The hasher comes from [`RandomState`], whose SipHash keys the standard
/// library seeds randomly (best effort from the host's secure randomness
/// source). The previous `DefaultHasher::new()` always uses the same fixed
/// keys. That made every key a pure function of (start second, pid, counter),
/// which a network attacker could enumerate. Time, pid and a counter are still
/// mixed in so successive keys stay distinct.
///
/// NOTE(review): `RandomState` is not a documented CSPRNG; if the crate can take
/// a dependency, `getrandom` would be the more principled entropy source.
fn generate_key() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::SeqCst);

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();

    let mut hasher = RandomState::new().build_hasher();
    nanos.hash(&mut hasher);
    std::process::id().hash(&mut hasher);
    seq.hash(&mut hasher);
    format!("bk_{:016x}", hasher.finish())
}
216
/// Whole seconds elapsed since the Unix epoch; `0` if the clock reads earlier than 1970.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
220
221/// Axum middleware: enforce API key auth when in ApiKey mode.
222pub async fn auth_layer(
223    auth_store: AuthStore,
224    request: Request<Body>,
225    next: Next,
226) -> Result<Response<Body>, StatusCode> {
227    if !auth_store.requires_auth() {
228        return Ok(next.run(request).await);
229    }
230
231    // Health/probe endpoints always accessible (for load balancer probes)
232    let path_str = request.uri().path();
233    if path_str.starts_with("/health") {
234        return Ok(next.run(request).await);
235    }
236
237    // Extract Bearer token
238    let token = request
239        .headers()
240        .get("authorization")
241        .and_then(|v| v.to_str().ok())
242        .and_then(|v| v.strip_prefix("Bearer "));
243
244    let path = request.uri().path().to_string();
245    match token {
246        Some(t) if auth_store.validate_for_path(t, &path) => Ok(next.run(request).await),
247        Some(_) => Err(StatusCode::FORBIDDEN), // valid key, wrong scope
248        _ => Err(StatusCode::UNAUTHORIZED),
249    }
250}