openai_protocol/
validated.rs

// Validated JSON extractor for automatic request validation
//
// This module provides a ValidatedJson extractor that deserializes a JSON request
// body, normalizes it via the Normalizable trait, and then validates it using the
// validator crate's Validate trait.
5
/// Hook for request types that need fixing up right after deserialization.
///
/// The `ValidatedJson` extractor calls `normalize` on the freshly deserialized value
/// before running validation, so implementations can fill in defaults that depend on
/// other fields. The provided implementation leaves the value untouched.
pub trait Normalizable {
    /// Applies defaults and transformations to `self` in place.
    ///
    /// Does nothing unless overridden.
    fn normalize(&mut self) {}
}
13
14#[cfg(feature = "axum")]
15use axum::{
16    extract::{rejection::JsonRejection, FromRequest, Request},
17    http::StatusCode,
18    response::{IntoResponse, Response},
19    Json,
20};
21#[cfg(feature = "axum")]
22use serde::de::DeserializeOwned;
23#[cfg(feature = "axum")]
24use serde_json::json;
25#[cfg(feature = "axum")]
26use validator::Validate;
27
/// A JSON extractor that deserializes, normalizes and validates the request body.
///
/// The body is first deserialized with axum's [`Json`] extractor. Then
/// [`Normalizable::normalize`] is called on the value. Finally,
/// [`Validate::validate`] runs on the normalized value. If any step fails, the
/// handler is not called and an OpenAI-style error response is returned instead:
///
/// ```json
/// {"error": {"message": "...", "type": "invalid_request_error", "code": ...}}
/// ```
///
/// Malformed JSON, a missing `Content-Type: application/json` header and failed
/// validation are all reported as `400 Bad Request`.
///
/// The wrapper dereferences to `T`. It can also be destructured to take ownership of
/// the inner value.
///
/// # Example
///
/// ```rust,ignore
/// async fn create_chat(
///     ValidatedJson(request): ValidatedJson<ChatCompletionRequest>,
/// ) -> Response {
///     // request is guaranteed to be valid here
///     process_request(request).await
/// }
/// ```
#[cfg(feature = "axum")]
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
46
#[cfg(feature = "axum")]
impl<S, T> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate + Normalizable + Send,
    S: Send + Sync,
{
    type Rejection = Response;

    /// Deserializes the body as JSON, normalizes it, then validates it.
    ///
    /// # Errors
    ///
    /// Returns an OpenAI-style `{"error": {...}}` response when the body cannot be
    /// extracted as JSON, or with `400 Bad Request` when `validate()` fails.
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let Json(mut data) = Json::<T>::from_request(req, state)
            .await
            .map_err(json_rejection_response)?;

        // Normalization runs first so that defaults derived from other fields are in
        // place before the validation rules inspect the request.
        data.normalize();

        data.validate().map_err(|validation_errors| {
            // NOTE(review): validation errors report the integer `400` as `code`, while
            // JSON parse errors use the string "json_parse_error". This is kept as-is so
            // existing clients see an unchanged payload.
            invalid_request_response(
                StatusCode::BAD_REQUEST,
                validation_errors.to_string(),
                json!(400),
            )
        })?;

        Ok(ValidatedJson(data))
    }
}

/// Converts an axum `JsonRejection` into an OpenAI-style error response.
///
/// Data errors, syntax errors and a missing `Content-Type` header are reported as
/// `400 Bad Request`. Any other rejection keeps the status axum assigns to it rather
/// than being forced to 400. An example is failing to buffer a body that exceeds the
/// size limit, which axum reports as `413 Payload Too Large`.
#[cfg(feature = "axum")]
fn json_rejection_response(err: JsonRejection) -> Response {
    let (status, message) = match err {
        JsonRejection::JsonDataError(e) => {
            (StatusCode::BAD_REQUEST, format!("Invalid JSON data: {}", e))
        }
        JsonRejection::JsonSyntaxError(e) => {
            (StatusCode::BAD_REQUEST, format!("JSON syntax error: {}", e))
        }
        JsonRejection::MissingJsonContentType(_) => (
            StatusCode::BAD_REQUEST,
            "Missing Content-Type: application/json header".to_string(),
        ),
        // `JsonRejection` is non-exhaustive, so a catch-all arm is required.
        other => (other.status(), format!("Failed to parse JSON: {}", other)),
    };
    invalid_request_response(status, message, json!("json_parse_error"))
}

/// Builds the `{"error": {"message", "type", "code"}}` response shared by every
/// rejection produced by `ValidatedJson`.
#[cfg(feature = "axum")]
fn invalid_request_response(
    status: StatusCode,
    message: String,
    code: serde_json::Value,
) -> Response {
    (
        status,
        Json(json!({
            "error": {
                "message": message,
                "type": "invalid_request_error",
                "code": code
            }
        })),
    )
        .into_response()
}
108
/// Gives read access to the extracted value, so a `ValidatedJson<T>` can be used
/// wherever a `&T` is expected.
#[cfg(feature = "axum")]
impl<T> std::ops::Deref for ValidatedJson<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let ValidatedJson(inner) = self;
        inner
    }
}
118
/// Gives mutable access to the extracted value through the wrapper.
#[cfg(feature = "axum")]
impl<T> std::ops::DerefMut for ValidatedJson<T> {
    fn deref_mut(&mut self) -> &mut T {
        let ValidatedJson(inner) = self;
        inner
    }
}
125
#[cfg(all(test, feature = "axum"))]
mod tests {
    use axum::body::Body;
    use serde::{Deserialize, Serialize};
    use validator::Validate;

    use super::*;

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestRequest {
        #[validate(range(min = 0.0, max = 1.0))]
        value: f32,
        #[validate(length(min = 1))]
        name: String,
    }

    impl Normalizable for TestRequest {
        // Use default no-op implementation
    }

    /// Request whose `normalize` fills in a value that validation requires. It is
    /// used to check that the extractor normalizes before it validates.
    #[derive(Debug, Deserialize, Validate)]
    struct DefaultedRequest {
        #[serde(default)]
        #[validate(length(min = 1))]
        name: String,
    }

    impl Normalizable for DefaultedRequest {
        fn normalize(&mut self) {
            if self.name.is_empty() {
                self.name = "default".to_string();
            }
        }
    }

    /// Builds a POST request with `body`, optionally tagged as `application/json`.
    fn build_request(body: &str, with_json_content_type: bool) -> Request {
        let mut builder = axum::http::Request::builder().method("POST").uri("/");
        if with_json_content_type {
            builder = builder.header("content-type", "application/json");
        }
        builder
            .body(Body::from(body.to_owned()))
            .expect("static request parts are valid")
    }

    /// Runs the extractor, which must reject `req`, and returns the status and JSON
    /// body of the rejection.
    async fn rejection<T>(req: Request) -> (StatusCode, serde_json::Value)
    where
        T: DeserializeOwned + Validate + Normalizable + Send,
    {
        let Err(response) = ValidatedJson::<T>::from_request(req, &()).await else {
            panic!("expected the extractor to reject the request");
        };
        let status = response.status();
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("error body can be read");
        let body = serde_json::from_slice(&bytes).expect("error body is valid JSON");
        (status, body)
    }

    #[test]
    fn test_validated_json_valid() {
        let request = TestRequest {
            value: 0.5,
            name: "test".to_string(),
        };
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_validated_json_invalid_range() {
        let request = TestRequest {
            value: 1.5, // Out of range
            name: "test".to_string(),
        };
        assert!(request.validate().is_err());
    }

    #[test]
    fn test_validated_json_invalid_length() {
        let request = TestRequest {
            value: 0.5,
            name: "".to_string(), // Empty name
        };
        assert!(request.validate().is_err());
    }

    #[tokio::test]
    async fn test_extractor_accepts_valid_body() {
        let req = build_request(r#"{"value": 0.5, "name": "test"}"#, true);
        let Ok(ValidatedJson(parsed)) =
            ValidatedJson::<TestRequest>::from_request(req, &()).await
        else {
            panic!("expected a valid request to be accepted");
        };
        assert_eq!(parsed.value, 0.5);
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn test_extractor_normalizes_before_validating() {
        let req = build_request(r#"{"name": ""}"#, true);
        let Ok(ValidatedJson(parsed)) =
            ValidatedJson::<DefaultedRequest>::from_request(req, &()).await
        else {
            panic!("normalization should make the request valid");
        };
        assert_eq!(parsed.name, "default");
    }

    #[tokio::test]
    async fn test_extractor_rejects_validation_failure() {
        let req = build_request(r#"{"value": 1.5, "name": "test"}"#, true);
        let (status, body) = rejection::<TestRequest>(req).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body["error"]["type"], "invalid_request_error");
        assert_eq!(body["error"]["code"], 400_i64);
    }

    #[tokio::test]
    async fn test_extractor_rejects_malformed_json() {
        let req = build_request(r#"{"value": "#, true);
        let (status, body) = rejection::<TestRequest>(req).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body["error"]["type"], "invalid_request_error");
        assert_eq!(body["error"]["code"], "json_parse_error");
    }

    #[tokio::test]
    async fn test_extractor_rejects_missing_content_type() {
        let req = build_request(r#"{"value": 0.5, "name": "test"}"#, false);
        let (status, body) = rejection::<TestRequest>(req).await;
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(
            body["error"]["message"],
            "Missing Content-Type: application/json header"
        );
        assert_eq!(body["error"]["code"], "json_parse_error");
    }
}