// openclaw_gateway/auth/middleware.rs

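//! Authentication middleware for axum: shared [`AuthState`] plus the request extractors [`RequireAuth`], [`OptionalAuth`] and [`RequireAdmin`].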
2
3use std::sync::Arc;
4
5use axum::{
6    Json,
7    extract::{FromRef, FromRequestParts},
8    http::{StatusCode, header::AUTHORIZATION, request::Parts},
9    response::{IntoResponse, Response},
10};
11use serde::Serialize;
12use tokio::sync::RwLock;
13
14use super::AuthError;
15use super::config::AuthConfig;
16use super::jwt::{Claims, JwtManager};
17use super::setup::BootstrapManager;
18use super::users::{UserRole, UserStore};
19
20/// Shared authentication state.
21pub struct AuthState {
22    /// Auth configuration.
23    pub config: AuthConfig,
24    /// JWT manager.
25    pub jwt: JwtManager,
26    /// User store.
27    pub users: UserStore,
28    /// Bootstrap manager.
29    pub bootstrap: RwLock<BootstrapManager>,
30}
31
32impl AuthState {
33    /// Create a new auth state.
34    #[must_use]
35    pub fn new(config: AuthConfig, jwt: JwtManager, users: UserStore) -> Self {
36        Self {
37            config,
38            jwt,
39            users,
40            bootstrap: RwLock::new(BootstrapManager::new()),
41        }
42    }
43
44    /// Initialize auth state, auto-generating JWT secret if needed.
45    ///
46    /// # Errors
47    ///
48    /// Returns error if initialization fails.
49    pub fn initialize(
50        mut config: AuthConfig,
51        data_dir: &std::path::Path,
52    ) -> Result<Self, AuthError> {
53        // Open user store
54        let users = UserStore::open(data_dir)?;
55
56        // Generate or load JWT secret
57        let jwt_secret = if let Some(secret) = &config.jwt_secret {
58            secret.clone()
59        } else {
60            let secret = JwtManager::generate_hex_secret();
61            config.jwt_secret = Some(secret.clone());
62            // In a real implementation, we'd persist this to config
63            tracing::info!("Generated new JWT secret");
64            secret
65        };
66
67        let jwt = JwtManager::from_hex_secret(
68            &jwt_secret,
69            config.token_expiry(),
70            config.refresh_expiry(),
71        )?;
72
73        Ok(Self::new(config, jwt, users))
74    }
75
76    /// Check if auth is required for a method.
77    #[must_use]
78    pub fn requires_auth(&self, method: &str) -> bool {
79        self.config.enabled && !self.config.is_public_method(method)
80    }
81
82    /// Validate a token and return claims.
83    ///
84    /// # Errors
85    ///
86    /// Returns error if token is invalid.
87    pub fn validate_token(&self, token: &str) -> Result<Claims, AuthError> {
88        self.jwt.validate_access_token(token)
89    }
90}
91
92impl std::fmt::Debug for AuthState {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("AuthState")
95            .field("config", &self.config)
96            .field("user_count", &self.users.count())
97            .finish_non_exhaustive()
98    }
99}
100
/// Marker type for auth-protected routes.
///
/// A data-less unit struct; it exists only to tag protected routes.
#[derive(Clone, Debug)]
pub struct AuthLayer;
104
105/// Extractor for authenticated requests.
106///
107/// Use this in handler parameters to require authentication.
108#[derive(Debug, Clone)]
109pub struct RequireAuth {
110    /// The authenticated user's claims.
111    pub claims: Claims,
112}
113
114impl RequireAuth {
115    /// Get the user ID.
116    #[must_use]
117    pub fn user_id(&self) -> &str {
118        &self.claims.sub
119    }
120
121    /// Get the username.
122    #[must_use]
123    pub fn username(&self) -> &str {
124        &self.claims.username
125    }
126
127    /// Get the user role.
128    #[must_use]
129    pub const fn role(&self) -> UserRole {
130        self.claims.role
131    }
132
133    /// Check if user is admin.
134    #[must_use]
135    pub fn is_admin(&self) -> bool {
136        self.claims.role.is_admin()
137    }
138
139    /// Require admin role.
140    ///
141    /// # Errors
142    ///
143    /// Returns error if user is not admin.
144    pub fn require_admin(&self) -> Result<(), AuthError> {
145        if self.is_admin() {
146            Ok(())
147        } else {
148            Err(AuthError::PermissionDenied(
149                "Admin role required".to_string(),
150            ))
151        }
152    }
153}
154
155/// Error response for auth failures.
156#[derive(Debug, Serialize)]
157struct AuthErrorResponse {
158    error: String,
159    code: &'static str,
160}
161
162impl IntoResponse for AuthError {
163    fn into_response(self) -> Response {
164        let (status, code) = match &self {
165            Self::InvalidCredentials => (StatusCode::UNAUTHORIZED, "invalid_credentials"),
166            Self::TokenError(_) => (StatusCode::UNAUTHORIZED, "invalid_token"),
167            Self::PermissionDenied(_) => (StatusCode::FORBIDDEN, "permission_denied"),
168            Self::SetupRequired => (StatusCode::SERVICE_UNAVAILABLE, "setup_required"),
169            Self::InvalidBootstrapToken => (StatusCode::UNAUTHORIZED, "invalid_bootstrap_token"),
170            Self::UserNotFound(_) => (StatusCode::NOT_FOUND, "user_not_found"),
171            Self::UserExists(_) => (StatusCode::CONFLICT, "user_exists"),
172            Self::Storage(_) | Self::Config(_) => {
173                (StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
174            }
175        };
176
177        let body = AuthErrorResponse {
178            error: self.to_string(),
179            code,
180        };
181
182        (status, Json(body)).into_response()
183    }
184}
185
186/// Extractor implementation for `RequireAuth`.
187impl<S> FromRequestParts<S> for RequireAuth
188where
189    S: Send + Sync,
190    Arc<AuthState>: FromRef<S>,
191{
192    type Rejection = Response;
193
194    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
195        let auth_state = Arc::<AuthState>::from_ref(state);
196        extract_auth(parts, &auth_state).await
197    }
198}
199
200async fn extract_auth(parts: &Parts, auth_state: &AuthState) -> Result<RequireAuth, Response> {
201    // Check if auth is disabled
202    if !auth_state.config.enabled {
203        // Return a dummy admin claim when auth is disabled
204        return Ok(RequireAuth {
205            claims: Claims {
206                sub: "system".to_string(),
207                username: "system".to_string(),
208                role: UserRole::Admin,
209                iat: 0,
210                exp: i64::MAX,
211                token_type: super::jwt::TokenType::Access,
212                family_id: None,
213            },
214        });
215    }
216
217    // Extract token from Authorization header
218    let auth_header = parts
219        .headers
220        .get(AUTHORIZATION)
221        .and_then(|v| v.to_str().ok())
222        .ok_or_else(|| {
223            AuthError::TokenError("Missing Authorization header".to_string()).into_response()
224        })?;
225
226    let token = JwtManager::extract_from_header(auth_header).ok_or_else(|| {
227        AuthError::TokenError("Invalid Authorization header format".to_string()).into_response()
228    })?;
229
230    // Validate token
231    let claims = auth_state
232        .validate_token(token)
233        .map_err(IntoResponse::into_response)?;
234
235    // Check if user is still active
236    let user = auth_state
237        .users
238        .get(&claims.sub)
239        .map_err(IntoResponse::into_response)?
240        .ok_or_else(|| AuthError::UserNotFound(claims.sub.clone()).into_response())?;
241
242    if !user.active {
243        return Err(AuthError::PermissionDenied("Account disabled".to_string()).into_response());
244    }
245
246    Ok(RequireAuth { claims })
247}
248
249/// Extractor for optional authentication.
250///
251/// Returns `None` if no valid auth is present, `Some(RequireAuth)` otherwise.
252#[derive(Debug, Clone)]
253pub struct OptionalAuth(pub Option<RequireAuth>);
254
255impl<S> FromRequestParts<S> for OptionalAuth
256where
257    S: Send + Sync,
258    Arc<AuthState>: FromRef<S>,
259{
260    type Rejection = std::convert::Infallible;
261
262    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
263        Ok(Self(
264            RequireAuth::from_request_parts(parts, state).await.ok(),
265        ))
266    }
267}
268
269/// Require admin role extractor.
270#[derive(Debug, Clone)]
271pub struct RequireAdmin(pub RequireAuth);
272
273impl<S> FromRequestParts<S> for RequireAdmin
274where
275    S: Send + Sync,
276    Arc<AuthState>: FromRef<S>,
277{
278    type Rejection = Response;
279
280    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
281        let auth = RequireAuth::from_request_parts(parts, state).await?;
282
283        if !auth.is_admin() {
284            return Err(
285                AuthError::PermissionDenied("Admin role required".to_string()).into_response(),
286            );
287        }
288
289        Ok(Self(auth))
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::jwt::TokenType;

    /// Builds a non-expiring access-token claim set for the given identity.
    fn access_claims(sub: &str, username: &str, role: UserRole) -> Claims {
        Claims {
            sub: sub.to_string(),
            username: username.to_string(),
            role,
            iat: 0,
            exp: i64::MAX,
            token_type: TokenType::Access,
            family_id: None,
        }
    }

    #[test]
    fn test_require_auth_methods() {
        let operator = RequireAuth {
            claims: access_claims("user_123", "testuser", UserRole::Operator),
        };

        assert_eq!(operator.user_id(), "user_123");
        assert_eq!(operator.username(), "testuser");
        assert_eq!(operator.role(), UserRole::Operator);
        assert!(!operator.is_admin());
        assert!(operator.require_admin().is_err());
    }

    #[test]
    fn test_admin_auth() {
        let admin = RequireAuth {
            claims: access_claims("admin_1", "admin", UserRole::Admin),
        };

        assert!(admin.is_admin());
        assert!(admin.require_admin().is_ok());
    }
}