openai_protocol/
/// In-place canonicalisation hook for a freshly deserialized value.
///
/// `ValidatedJson` invokes [`Normalizable::normalize`] after deserializing a
/// request body and before validating it. Types that need no adjustment can
/// opt in with an empty `impl Normalizable for MyType {}`, because the
/// provided method is a no-op.
pub trait Normalizable {
    /// Rewrites `self` in place. The provided implementation leaves the value
    /// exactly as it is.
    fn normalize(&mut self) {}
}
13
14#[cfg(feature = "axum")]
15use axum::{
16 extract::{rejection::JsonRejection, FromRequest, Request},
17 http::StatusCode,
18 response::{IntoResponse, Response},
19 Json,
20};
21#[cfg(feature = "axum")]
22use serde::de::DeserializeOwned;
23#[cfg(feature = "axum")]
24use serde_json::json;
25#[cfg(feature = "axum")]
26use validator::Validate;
27
#[cfg(feature = "axum")]
/// Axum extractor that deserializes a JSON request body into `T`, normalizes
/// it with [`Normalizable::normalize`], and then checks it with
/// [`Validate::validate`].
///
/// Any failure is turned into a `400 Bad Request` response whose body is an
/// OpenAI-style `{"error": {"message", "type", "code"}}` JSON object. The
/// payload is reachable through the public `.0` field or via deref.
// The derives mirror `axum::Json` and are conditional on `T` implementing
// the corresponding trait, so they add no requirements for existing users.
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
46
/// Builds the OpenAI-style `invalid_request_error` response used by every
/// rejection path of [`ValidatedJson`].
///
/// `message` becomes `error.message`; `code` is emitted verbatim as
/// `error.code`, so each caller keeps its existing value.
#[cfg(feature = "axum")]
fn invalid_request_response(message: String, code: serde_json::Value) -> Response {
    (
        StatusCode::BAD_REQUEST,
        Json(json!({
            "error": {
                "message": message,
                "type": "invalid_request_error",
                "code": code
            }
        })),
    )
        .into_response()
}

/// Produces a human-readable description of why the JSON body was rejected.
#[cfg(feature = "axum")]
fn json_rejection_message(rejection: &JsonRejection) -> String {
    // `body_text()` is the text axum itself would send for the rejection.
    // Using it surfaces the underlying deserializer error to the client
    // (e.g. which field is missing and where) instead of a generic summary.
    match rejection {
        JsonRejection::JsonDataError(e) => format!("Invalid JSON data: {}", e.body_text()),
        JsonRejection::JsonSyntaxError(e) => format!("JSON syntax error: {}", e.body_text()),
        JsonRejection::MissingJsonContentType(_) => {
            "Missing Content-Type: application/json header".to_string()
        }
        // `JsonRejection` is non-exhaustive, so a catch-all arm is required.
        other => format!("Failed to parse JSON: {}", other.body_text()),
    }
}

#[cfg(feature = "axum")]
impl<S, T> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate + Normalizable + Send,
    S: Send + Sync,
{
    type Rejection = Response;

    /// Deserializes the request body as JSON, normalizes it, then validates it.
    ///
    /// # Errors
    ///
    /// Returns a `400 Bad Request` response when:
    /// - the body cannot be extracted as JSON for `T`, including a missing JSON
    ///   content type (`code` = `"json_parse_error"`); or
    /// - `validate()` fails on the normalized value (`code` = `400`).
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let Json(mut data) = Json::<T>::from_request(req, state)
            .await
            .map_err(|rejection: JsonRejection| {
                invalid_request_response(
                    json_rejection_message(&rejection),
                    json!("json_parse_error"),
                )
            })?;

        // Normalize first so validation runs against the canonical form.
        data.normalize();

        data.validate().map_err(|validation_errors| {
            // NOTE(review): `code` is numeric here but a string in the parse
            // error path. OpenAI-style clients typically expect a string or
            // null. Kept as-is to avoid breaking existing consumers; confirm
            // before changing.
            invalid_request_response(validation_errors.to_string(), json!(400))
        })?;

        Ok(ValidatedJson(data))
    }
}
108
#[cfg(feature = "axum")]
impl<T> std::ops::Deref for ValidatedJson<T> {
    type Target = T;

    /// Borrows the wrapped payload, letting a `ValidatedJson<T>` be used
    /// wherever a `&T` is expected.
    fn deref(&self) -> &T {
        let ValidatedJson(inner) = self;
        inner
    }
}
118
#[cfg(feature = "axum")]
impl<T> std::ops::DerefMut for ValidatedJson<T> {
    /// Mutably borrows the wrapped payload.
    fn deref_mut(&mut self) -> &mut T {
        let Self(inner) = self;
        inner
    }
}
125
#[cfg(all(test, feature = "axum"))]
mod tests {
    use serde::{Deserialize, Serialize};
    use validator::Validate;

    use super::*;

    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestRequest {
        #[validate(range(min = 0.0, max = 1.0))]
        value: f32,
        #[validate(length(min = 1))]
        name: String,
    }

    // Uses the default no-op `normalize`.
    impl Normalizable for TestRequest {}

    /// Request whose `normalize` fills in an empty name. It is used to check
    /// that the extractor normalizes before it validates.
    #[derive(Debug, Deserialize, Validate)]
    struct DefaultingRequest {
        #[validate(length(min = 1))]
        name: String,
    }

    impl Normalizable for DefaultingRequest {
        fn normalize(&mut self) {
            if self.name.is_empty() {
                self.name = "default".to_string();
            }
        }
    }

    /// Runs the `ValidatedJson` extractor on a POST request carrying `body`,
    /// optionally tagged with a JSON `Content-Type` header.
    async fn extract<T>(body: &str, json_content_type: bool) -> Result<ValidatedJson<T>, Response>
    where
        T: DeserializeOwned + Validate + Normalizable + Send,
    {
        let mut builder = axum::http::Request::builder().method("POST").uri("/");
        if json_content_type {
            builder = builder.header("content-type", "application/json");
        }
        let req = builder
            .body(axum::body::Body::from(body.to_owned()))
            .expect("request should build");
        ValidatedJson::<T>::from_request(req, &()).await
    }

    /// Reads a rejection response body back as JSON.
    async fn error_body(response: Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        serde_json::from_slice(&bytes).expect("body should be JSON")
    }

    // The next three tests exercise only the derived `validator` rules.

    #[tokio::test]
    async fn test_validated_json_valid() {
        let request = TestRequest {
            value: 0.5,
            name: "test".to_string(),
        };
        assert!(request.validate().is_ok());
    }

    #[tokio::test]
    async fn test_validated_json_invalid_range() {
        let request = TestRequest {
            value: 1.5,
            name: "test".to_string(),
        };
        assert!(request.validate().is_err());
    }

    #[tokio::test]
    async fn test_validated_json_invalid_length() {
        let request = TestRequest {
            value: 0.5,
            name: "".to_string(),
        };
        assert!(request.validate().is_err());
    }

    // The remaining tests drive the extractor end to end.

    #[tokio::test]
    async fn test_extractor_accepts_valid_payload() {
        let Ok(ValidatedJson(parsed)) =
            extract::<TestRequest>(r#"{"value":0.5,"name":"test"}"#, true).await
        else {
            panic!("valid payload should be accepted");
        };
        assert_eq!(parsed.value, 0.5);
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn test_extractor_rejects_out_of_range_value() {
        let Err(response) =
            extract::<TestRequest>(r#"{"value":1.5,"name":"test"}"#, true).await
        else {
            panic!("out-of-range value should be rejected");
        };
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let payload = error_body(response).await;
        assert_eq!(payload["error"]["type"], "invalid_request_error");
        assert_eq!(payload["error"]["code"], 400);
    }

    #[tokio::test]
    async fn test_extractor_rejects_malformed_json() {
        let Err(response) = extract::<TestRequest>(r#"{"value":"#, true).await else {
            panic!("malformed JSON should be rejected");
        };
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let payload = error_body(response).await;
        assert_eq!(payload["error"]["code"], "json_parse_error");
    }

    #[tokio::test]
    async fn test_extractor_rejects_missing_content_type() {
        let Err(response) =
            extract::<TestRequest>(r#"{"value":0.5,"name":"test"}"#, false).await
        else {
            panic!("missing content type should be rejected");
        };
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let payload = error_body(response).await;
        assert_eq!(
            payload["error"]["message"],
            "Missing Content-Type: application/json header"
        );
    }

    #[tokio::test]
    async fn test_extractor_normalizes_before_validating() {
        // An empty name would fail `length(min = 1)` unless normalize runs first.
        let Ok(ValidatedJson(parsed)) =
            extract::<DefaultingRequest>(r#"{"name":""}"#, true).await
        else {
            panic!("normalized payload should pass validation");
        };
        assert_eq!(parsed.name, "default");
    }
}