Skip to main content

guts_compat/
middleware.rs

1//! HTTP middleware for GitHub API compatibility.
2
3use serde::{Deserialize, Serialize};
4
5use crate::pagination::PaginationLinks;
6use crate::rate_limit::{RateLimitHeaders, RateLimitResource, RateLimitState};
7use crate::token::TokenScope;
8use crate::user::UserId;
9
10/// Authentication context extracted from request.
11#[derive(Debug, Clone)]
12pub struct AuthContext {
13    /// User ID if authenticated.
14    pub user_id: Option<UserId>,
15    /// Username if authenticated.
16    pub username: Option<String>,
17    /// Token scopes if authenticated via token.
18    pub scopes: Vec<TokenScope>,
19    /// Whether the user is authenticated.
20    pub authenticated: bool,
21    /// Client IP for rate limiting.
22    pub client_ip: String,
23}
24
25impl AuthContext {
26    /// Create an unauthenticated context.
27    pub fn anonymous(client_ip: String) -> Self {
28        Self {
29            user_id: None,
30            username: None,
31            scopes: Vec::new(),
32            authenticated: false,
33            client_ip,
34        }
35    }
36
37    /// Create an authenticated context.
38    pub fn authenticated(
39        user_id: UserId,
40        username: String,
41        scopes: Vec<TokenScope>,
42        client_ip: String,
43    ) -> Self {
44        Self {
45            user_id: Some(user_id),
46            username: Some(username),
47            scopes,
48            authenticated: true,
49            client_ip,
50        }
51    }
52
53    /// Check if a scope is granted.
54    pub fn has_scope(&self, scope: TokenScope) -> bool {
55        if !self.authenticated {
56            return false;
57        }
58
59        // Admin scope grants all
60        if self.scopes.contains(&TokenScope::Admin) {
61            return true;
62        }
63
64        self.scopes.contains(&scope)
65    }
66
67    /// Get the rate limit key (user ID or IP).
68    pub fn rate_limit_key(&self) -> String {
69        if let Some(id) = self.user_id {
70            format!("user:{}", id)
71        } else {
72            format!("ip:{}", self.client_ip)
73        }
74    }
75}
76
77/// GitHub-compatible error response.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct ErrorResponse {
80    /// Error message.
81    pub message: String,
82    /// Documentation URL.
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub documentation_url: Option<String>,
85    /// Validation errors (for 422 responses).
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub errors: Option<Vec<ValidationError>>,
88}
89
90impl ErrorResponse {
91    /// Create a simple error response.
92    pub fn new(message: impl Into<String>) -> Self {
93        Self {
94            message: message.into(),
95            documentation_url: None,
96            errors: None,
97        }
98    }
99
100    /// Create an error with documentation URL.
101    pub fn with_docs(message: impl Into<String>, docs_url: impl Into<String>) -> Self {
102        Self {
103            message: message.into(),
104            documentation_url: Some(docs_url.into()),
105            errors: None,
106        }
107    }
108
109    /// Create a validation error response.
110    pub fn validation(message: impl Into<String>, errors: Vec<ValidationError>) -> Self {
111        Self {
112            message: message.into(),
113            documentation_url: None,
114            errors: Some(errors),
115        }
116    }
117
118    /// Standard "Not Found" error.
119    pub fn not_found() -> Self {
120        Self::new("Not Found")
121    }
122
123    /// Standard "Bad credentials" error.
124    pub fn bad_credentials() -> Self {
125        Self::new("Bad credentials")
126    }
127
128    /// Standard "Forbidden" error.
129    pub fn forbidden() -> Self {
130        Self::new("Forbidden")
131    }
132
133    /// Standard rate limit error.
134    pub fn rate_limited(reset: u64) -> Self {
135        Self::with_docs(
136            format!(
137                "API rate limit exceeded. Rate limit will reset at {}",
138                reset
139            ),
140            "https://docs.guts.network/rest/rate-limiting",
141        )
142    }
143}
144
145/// Validation error detail.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct ValidationError {
148    /// Resource type.
149    pub resource: String,
150    /// Field name.
151    pub field: String,
152    /// Error code.
153    pub code: ValidationErrorCode,
154    /// Additional message.
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub message: Option<String>,
157}
158
159impl ValidationError {
160    /// Create a new validation error.
161    pub fn new(
162        resource: impl Into<String>,
163        field: impl Into<String>,
164        code: ValidationErrorCode,
165    ) -> Self {
166        Self {
167            resource: resource.into(),
168            field: field.into(),
169            code,
170            message: None,
171        }
172    }
173
174    /// Create with a message.
175    pub fn with_message(
176        resource: impl Into<String>,
177        field: impl Into<String>,
178        code: ValidationErrorCode,
179        message: impl Into<String>,
180    ) -> Self {
181        Self {
182            resource: resource.into(),
183            field: field.into(),
184            code,
185            message: Some(message.into()),
186        }
187    }
188}
189
190/// Validation error codes.
191#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
192#[serde(rename_all = "snake_case")]
193pub enum ValidationErrorCode {
194    /// Field is missing.
195    Missing,
196    /// Field value is missing (null).
197    MissingField,
198    /// Field value is invalid.
199    Invalid,
200    /// Resource already exists.
201    AlreadyExists,
202    /// Value is not unique.
203    NotUnique,
204    /// Value is too long.
205    TooLong,
206    /// Value is too short.
207    TooShort,
208    /// Custom error.
209    Custom,
210}
211
212/// Response headers builder.
213#[derive(Debug, Clone, Default)]
214pub struct ResponseHeaders {
215    /// Rate limit headers.
216    pub rate_limit: Option<RateLimitHeaders>,
217    /// Pagination Link header.
218    pub link: Option<String>,
219    /// ETag header.
220    pub etag: Option<String>,
221    /// Last-Modified header.
222    pub last_modified: Option<String>,
223    /// Cache-Control header.
224    pub cache_control: Option<String>,
225}
226
227impl ResponseHeaders {
228    /// Create a new response headers builder.
229    pub fn new() -> Self {
230        Self::default()
231    }
232
233    /// Add rate limit headers.
234    pub fn with_rate_limit(mut self, state: &RateLimitState) -> Self {
235        self.rate_limit = Some(RateLimitHeaders::from(state));
236        self
237    }
238
239    /// Add pagination Link header.
240    pub fn with_pagination(mut self, links: &PaginationLinks) -> Self {
241        self.link = links.to_header_value();
242        self
243    }
244
245    /// Add ETag header.
246    pub fn with_etag(mut self, etag: impl Into<String>) -> Self {
247        self.etag = Some(format!("\"{}\"", etag.into()));
248        self
249    }
250
251    /// Add cache control.
252    pub fn with_cache_control(mut self, value: impl Into<String>) -> Self {
253        self.cache_control = Some(value.into());
254        self
255    }
256
257    /// No cache.
258    pub fn no_cache(mut self) -> Self {
259        self.cache_control = Some("private, max-age=60, s-maxage=60".to_string());
260        self
261    }
262}
263
264/// Parse Authorization header.
265///
266/// Supports:
267/// - `Bearer <token>` - Personal access token
268/// - `Basic <base64>` - username:token as password
269/// - `token <token>` - GitHub-style token header
270pub fn parse_authorization_header(header: &str) -> Option<AuthorizationValue> {
271    let header = header.trim();
272
273    if let Some(token) = header.strip_prefix("Bearer ") {
274        return Some(AuthorizationValue::Bearer(token.trim().to_string()));
275    }
276
277    if let Some(token) = header.strip_prefix("token ") {
278        return Some(AuthorizationValue::Token(token.trim().to_string()));
279    }
280
281    if let Some(encoded) = header.strip_prefix("Basic ") {
282        if let Some((username, password)) = decode_basic_auth(encoded.trim()) {
283            return Some(AuthorizationValue::Basic { username, password });
284        }
285    }
286
287    None
288}
289
/// A credential extracted from an `Authorization` request header.
#[derive(Debug, Clone)]
pub enum AuthorizationValue {
    /// `Authorization: Bearer <token>`.
    Bearer(String),
    /// `Authorization: token <token>` (GitHub-style).
    Token(String),
    /// `Authorization: Basic <base64>`, already decoded into its two parts.
    Basic {
        /// Username portion of the credentials.
        username: String,
        /// Password portion of the credentials; may carry an API token.
        password: String,
    },
}

impl AuthorizationValue {
    /// Returns the API token carried by this credential, if any.
    ///
    /// Bearer and `token` credentials always yield their value. Basic credentials
    /// yield the password only when it starts with `guts_`, because clients using
    /// Basic auth put the token in the password slot.
    pub fn token(&self) -> Option<&str> {
        let candidate = match self {
            Self::Bearer(value) | Self::Token(value) => return Some(value.as_str()),
            Self::Basic { password, .. } => password.as_str(),
        };
        candidate.starts_with("guts_").then_some(candidate)
    }

    /// Returns the username for Basic credentials; `None` for every other scheme.
    pub fn username(&self) -> Option<&str> {
        if let Self::Basic { username, .. } = self {
            Some(username.as_str())
        } else {
            None
        }
    }
}
325
326/// Decode Basic auth header value.
327fn decode_basic_auth(encoded: &str) -> Option<(String, String)> {
328    // Simple base64 decode
329    let decoded = base64_decode(encoded)?;
330    let text = String::from_utf8(decoded).ok()?;
331
332    let (username, password) = text.split_once(':')?;
333    Some((username.to_string(), password.to_string()))
334}
335
/// Decodes standard, padded base64 (RFC 4648 §4 alphabet).
///
/// Surrounding whitespace is trimmed first. Returns `None` when the input:
/// - is empty, or its length is not a multiple of four;
/// - contains a character outside the standard alphabet;
/// - has misplaced padding.
///
/// `=` is accepted only as the last one or two characters of the final four-character
/// group. Padding elsewhere, or `=` followed by a data character, is rejected instead
/// of being decoded as zero bits.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    /// Maps one base64 alphabet character to its 6-bit value.
    fn sextet(symbol: u8) -> Option<u8> {
        match symbol {
            b'A'..=b'Z' => Some(symbol - b'A'),
            b'a'..=b'z' => Some(symbol - b'a' + 26),
            b'0'..=b'9' => Some(symbol - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let bytes = input.trim().as_bytes();
    let groups = bytes.chunks_exact(4);
    if bytes.is_empty() || !groups.remainder().is_empty() {
        return None;
    }

    let group_count = bytes.len() / 4;
    let mut decoded = Vec::with_capacity(group_count * 3);

    for (index, group) in groups.enumerate() {
        // Number of trailing `=` characters in this group (0, 1 or 2).
        let padding = match (group[2], group[3]) {
            (b'=', b'=') => 2,
            (_, b'=') => 1,
            // `=` followed by a data character is never valid.
            (b'=', _) => return None,
            _ => 0,
        };
        // Padding may only terminate the input.
        if padding > 0 && index + 1 != group_count {
            return None;
        }

        // `sextet` rejects `=`, so padding in the first two positions fails here.
        let a = sextet(group[0])?;
        let b = sextet(group[1])?;
        let c = if padding == 2 { 0 } else { sextet(group[2])? };
        let d = if padding >= 1 { 0 } else { sextet(group[3])? };

        // `<<` on u8 discards the high bits, leaving exactly the bits each output byte needs.
        decoded.push((a << 2) | (b >> 4));
        if padding < 2 {
            decoded.push((b << 4) | (c >> 2));
        }
        if padding == 0 {
            decoded.push((c << 6) | d);
        }
    }

    Some(decoded)
}
376
377/// Get resource type from request path.
378pub fn resource_from_path(path: &str) -> RateLimitResource {
379    if path.contains("/search") {
380        RateLimitResource::Search
381    } else if path.contains("/graphql") {
382        RateLimitResource::Graphql
383    } else if path.starts_with("/git/") {
384        RateLimitResource::Git
385    } else {
386        RateLimitResource::Core
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_context_anonymous() {
        let anonymous = AuthContext::anonymous(String::from("127.0.0.1"));

        assert!(!anonymous.authenticated);
        assert!(anonymous.user_id.is_none());
        assert!(!anonymous.has_scope(TokenScope::RepoRead));
    }

    #[test]
    fn test_auth_context_authenticated() {
        let scopes = vec![TokenScope::RepoRead];
        let alice = AuthContext::authenticated(
            1,
            String::from("alice"),
            scopes,
            String::from("127.0.0.1"),
        );

        assert!(alice.authenticated);
        assert_eq!(alice.user_id, Some(1));
        assert!(alice.has_scope(TokenScope::RepoRead));
        assert!(!alice.has_scope(TokenScope::RepoWrite));
    }

    #[test]
    fn test_rate_limit_key() {
        let ip = "10.0.0.1";

        let anonymous = AuthContext::anonymous(ip.to_string());
        assert_eq!(anonymous.rate_limit_key(), "ip:10.0.0.1");

        let bob = AuthContext::authenticated(42, "bob".into(), Vec::new(), ip.into());
        assert_eq!(bob.rate_limit_key(), "user:42");
    }

    #[test]
    fn test_parse_authorization_bearer() {
        let Some(AuthorizationValue::Bearer(token)) =
            parse_authorization_header("Bearer guts_abc12345_secret")
        else {
            panic!("Expected Bearer");
        };
        assert_eq!(token, "guts_abc12345_secret");
    }

    #[test]
    fn test_parse_authorization_token() {
        let Some(AuthorizationValue::Token(token)) =
            parse_authorization_header("token guts_abc12345_secret")
        else {
            panic!("Expected Token");
        };
        assert_eq!(token, "guts_abc12345_secret");
    }

    #[test]
    fn test_parse_authorization_basic() {
        // base64("user:pass") == "dXNlcjpwYXNz"
        let Some(AuthorizationValue::Basic { username, password }) =
            parse_authorization_header("Basic dXNlcjpwYXNz")
        else {
            panic!("Expected Basic");
        };
        assert_eq!(username, "user");
        assert_eq!(password, "pass");
    }

    #[test]
    fn test_error_response() {
        let ErrorResponse {
            message, errors, ..
        } = ErrorResponse::not_found();
        assert_eq!(message, "Not Found");
        assert!(errors.is_none());
    }

    #[test]
    fn test_validation_error() {
        let err = ValidationError::new("User", "username", ValidationErrorCode::AlreadyExists);
        assert_eq!(
            (err.resource.as_str(), err.field.as_str()),
            ("User", "username")
        );
    }

    #[test]
    fn test_resource_from_path() {
        let cases = [
            ("/api/search/repositories", RateLimitResource::Search),
            ("/git/owner/repo/info/refs", RateLimitResource::Git),
            ("/api/repos/owner/repo", RateLimitResource::Core),
        ];
        for (path, expected) in cases {
            assert_eq!(resource_from_path(path), expected, "path: {path}");
        }
    }
}