1use crate::config::OpenAIClient;
10use crate::error::OpenAIError;
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13
14use futures_util::stream::TryStreamExt;
16use tokio::io::{AsyncBufReadExt, BufReader};
17use tokio_stream::Stream; use tokio_stream::StreamExt as TokioStreamExt; use tokio_stream::wrappers::LinesStream;
20use tokio_util::io::StreamReader;
21
22pub(crate) async fn post_json<T, R>(
41 client: &OpenAIClient,
42 endpoint: &str,
43 body: &T,
44) -> Result<R, OpenAIError>
45where
46 T: Serialize,
47 R: DeserializeOwned,
48{
49 let url = format!("{}/{}", client.base_url().trim_end_matches('/'), endpoint);
50 let mut request_builder = client.http_client.post(&url).bearer_auth(client.api_key());
51
52 if let Some(org_id) = client.organization() {
54 request_builder = request_builder.header("OpenAI-Organization", org_id);
55 }
56
57 let response = request_builder.json(body).send().await?;
58
59 handle_response(response).await
60}
61
62pub(crate) async fn get_json<R>(client: &OpenAIClient, endpoint: &str) -> Result<R, OpenAIError>
80where
81 R: DeserializeOwned,
82{
83 let url = format!("{}/{}", client.base_url().trim_end_matches('/'), endpoint);
84 let mut request_builder = client.http_client.get(&url).bearer_auth(client.api_key());
85
86 if let Some(org_id) = client.organization() {
88 request_builder = request_builder.header("OpenAI-Organization", org_id);
89 }
90
91 let response = request_builder.send().await?;
92
93 handle_response(response).await
94}
95
96async fn handle_response<R>(response: reqwest::Response) -> Result<R, OpenAIError>
110where
111 R: DeserializeOwned,
112{
113 let status = response.status();
114 if status.is_success() {
115 let text = response.text().await?;
117
118 let parsed: R = serde_json::from_str(&text).map_err(OpenAIError::from)?;
120
121 Ok(parsed)
122 } else {
123 parse_error_response(response).await
124 }
125}
126
127pub async fn parse_error_response<R>(response: reqwest::Response) -> Result<R, OpenAIError> {
130 let status = response.status();
131 let text_body = response.text().await.unwrap_or_else(|_| "".to_string());
132
133 match serde_json::from_str::<crate::error::OpenAIAPIErrorBody>(&text_body) {
134 Ok(body) => Err(OpenAIError::from(body)),
135 Err(_) => {
136 let msg = format!(
137 "HTTP {} returned from OpenAI API; body: {}",
138 status, text_body
139 );
140 Err(OpenAIError::APIError {
141 message: msg,
142 err_type: None,
143 code: None,
144 })
145 }
146 }
147}
148
149pub async fn post_json_stream<T, R>(
170 client: &OpenAIClient,
171 endpoint: &str,
172 body: &T,
173) -> Result<impl Stream<Item = Result<R, OpenAIError>>, OpenAIError>
174where
175 T: Serialize,
176 R: DeserializeOwned + 'static,
177{
178 let url = format!("{}/{}", client.base_url().trim_end_matches('/'), endpoint);
179 let mut request_builder = client.http_client.post(&url).bearer_auth(client.api_key());
180
181 if let Some(org_id) = client.organization() {
182 request_builder = request_builder.header("OpenAI-Organization", org_id);
183 }
184
185 let response = request_builder.json(body).send().await?;
186
187 let status = response.status();
188 if !status.is_success() {
189 let text_body = response.text().await.unwrap_or_else(|_| "".to_string());
190 match serde_json::from_str::<crate::error::OpenAIAPIErrorBody>(&text_body) {
191 Ok(body_err) => return Err(OpenAIError::from(body_err)),
192 Err(_) => {
193 return Err(OpenAIError::APIError {
194 message: format!(
195 "HTTP {} returned from OpenAI API; body: {}",
196 status, text_body
197 ),
198 err_type: None,
199 code: None,
200 });
201 }
202 }
203 }
204
205 let byte_stream = response
207 .bytes_stream()
208 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
209 let stream_reader = StreamReader::new(byte_stream);
210 let buf_reader = BufReader::new(stream_reader);
211
212 let lines = LinesStream::new(buf_reader.lines());
214
215 let stream = lines.filter_map(|line_result| {
220 match line_result {
221 Ok(line) => {
222 let trimmed = line.trim();
223 if trimmed.is_empty() || trimmed.contains("[DONE]") {
225 None
226 } else {
227 let data = if trimmed.starts_with("data:") {
229 trimmed.trim_start_matches("data:").trim()
230 } else {
231 trimmed
232 };
233 match serde_json::from_str::<R>(data) {
235 Ok(parsed) => Some(Ok(parsed)),
236 Err(e) => {
237 eprintln!(
238 "Warning: failed to deserialize chunk: {:?} (error: {})",
239 data, e
240 );
241 None }
243 }
244 }
245 }
246 Err(e) => Some(Err(OpenAIError::from(e))),
247 }
248 });
249 Ok(stream)
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::OpenAIClient;
    use crate::error::{OpenAIError, OpenAIError::APIError};
    use serde::Deserialize;
    use tokio_stream::StreamExt;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Minimal response shape used as the deserialization target throughout these tests.
    #[derive(Debug, Deserialize)]
    struct MockResponse {
        pub foo: String,
        pub bar: i32,
    }

    /// Builds a client with a dummy API key whose base URL points at `server`.
    fn test_client(server: &MockServer) -> OpenAIClient {
        OpenAIClient::builder()
            .with_api_key("test-key")
            .with_base_url(&server.uri())
            .build()
            .unwrap()
    }

    /// Registers a mock on `server` that answers `http_method` requests to `route` with `response`.
    async fn mount_mock(
        server: &MockServer,
        http_method: &str,
        route: &str,
        response: ResponseTemplate,
    ) {
        Mock::given(method(http_method))
            .and(path(route))
            .respond_with(response)
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn test_post_json_success() {
        let mock_server = MockServer::start().await;
        let mock_data = serde_json::json!({ "foo": "hello", "bar": 42 });
        mount_mock(
            &mock_server,
            "POST",
            "/test-endpoint",
            ResponseTemplate::new(200).set_body_json(mock_data),
        )
        .await;
        let client = test_client(&mock_server);

        let request_body = serde_json::json!({ "dummy": true });
        let result: Result<MockResponse, OpenAIError> =
            post_json(&client, "test-endpoint", &request_body).await;

        assert!(result.is_ok(), "Expected Ok, got Err");
        let parsed = result.unwrap();
        assert_eq!(parsed.foo, "hello");
        assert_eq!(parsed.bar, 42);
    }

    #[tokio::test]
    async fn test_post_json_api_error() {
        let mock_server = MockServer::start().await;
        let error_body = serde_json::json!({
            "error": {
                "message": "Invalid request",
                "type": "invalid_request_error",
                "param": null,
                "code": "some_code"
            }
        });
        mount_mock(
            &mock_server,
            "POST",
            "/test-endpoint",
            ResponseTemplate::new(400).set_body_json(error_body),
        )
        .await;
        let client = test_client(&mock_server);

        let request_body = serde_json::json!({ "dummy": true });
        let result: Result<MockResponse, OpenAIError> =
            post_json(&client, "test-endpoint", &request_body).await;

        match result {
            Err(APIError { message, .. }) => {
                assert!(
                    message.contains("Invalid request"),
                    "Expected error message about invalid request, got: {}",
                    message
                );
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_post_json_deserialize_error() {
        let mock_server = MockServer::start().await;
        // Field types are swapped relative to `MockResponse`, so decoding must fail.
        let invalid_json = r#"{"foo": 123, "bar": "not_an_integer"}"#;
        mount_mock(
            &mock_server,
            "POST",
            "/test-endpoint",
            ResponseTemplate::new(200).set_body_raw(invalid_json, "application/json"),
        )
        .await;
        let client = test_client(&mock_server);

        let request_body = serde_json::json!({ "dummy": true });
        let result: Result<MockResponse, OpenAIError> =
            post_json(&client, "test-endpoint", &request_body).await;

        assert!(matches!(result, Err(OpenAIError::DeserializeError(_))));
    }

    #[tokio::test]
    async fn test_get_json_success() {
        let mock_server = MockServer::start().await;
        let mock_data = serde_json::json!({ "foo": "abc", "bar": 99 });
        mount_mock(
            &mock_server,
            "GET",
            "/test-get",
            ResponseTemplate::new(200).set_body_json(mock_data),
        )
        .await;
        let client = test_client(&mock_server);

        let result: Result<MockResponse, OpenAIError> = get_json(&client, "test-get").await;

        assert!(result.is_ok());
        let parsed = result.unwrap();
        assert_eq!(parsed.foo, "abc");
        assert_eq!(parsed.bar, 99);
    }

    #[tokio::test]
    async fn test_get_json_api_error() {
        let mock_server = MockServer::start().await;
        let error_body = serde_json::json!({
            "error": {
                "message": "Resource not found",
                "type": "not_found",
                "code": "missing_resource"
            }
        });
        mount_mock(
            &mock_server,
            "GET",
            "/test-get",
            ResponseTemplate::new(404).set_body_json(error_body),
        )
        .await;
        let client = test_client(&mock_server);

        let result: Result<MockResponse, OpenAIError> = get_json(&client, "test-get").await;

        match result {
            Err(APIError { message, .. }) => {
                assert!(message.contains("Resource not found"));
            }
            other => panic!("Expected APIError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_post_json_stream_success() {
        let mock_server = MockServer::start().await;
        // Server-sent events: two data chunks followed by the [DONE] terminator.
        let body = r#"data: {"foo":"first","bar":1}

data: {"foo":"second","bar":2}

data: [DONE]
"#;
        mount_mock(
            &mock_server,
            "POST",
            "/stream-endpoint",
            ResponseTemplate::new(200).set_body_raw(body, "text/event-stream"),
        )
        .await;
        let client = test_client(&mock_server);

        let req_body = serde_json::json!({ "stream": true });
        let stream = post_json_stream::<serde_json::Value, MockResponse>(
            &client,
            "stream-endpoint",
            &req_body,
        )
        .await
        .expect("stream should start OK");

        let items: Vec<MockResponse> = stream
            .map(|r| r.expect("chunk should be Ok"))
            .collect()
            .await;

        assert_eq!(items.len(), 2);
        assert_eq!(items[0].foo, "first");
        assert_eq!(items[0].bar, 1);
        assert_eq!(items[1].foo, "second");
        assert_eq!(items[1].bar, 2);
    }

    #[tokio::test]
    async fn test_post_json_stream_api_error() {
        let mock_server = MockServer::start().await;
        let error_body = serde_json::json!({
            "error": {
                "message": "Streaming not allowed",
                "type": "invalid_request_error",
                "code": "not_allowed"
            }
        });
        mount_mock(
            &mock_server,
            "POST",
            "/stream-endpoint",
            ResponseTemplate::new(403).set_body_json(error_body),
        )
        .await;
        let client = test_client(&mock_server);

        let req_body = serde_json::json!({ "stream": true });
        let result = post_json_stream::<serde_json::Value, MockResponse>(
            &client,
            "stream-endpoint",
            &req_body,
        )
        .await;

        // The stream type is not `Debug`, so the `Ok` arm cannot print its payload.
        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(message.contains("Streaming not allowed"));
            }
            Ok(_) => panic!("Expected APIError, got Ok(_)"),
            Err(other) => panic!("Expected APIError, got {:?}", other.to_string()),
        }
    }

    #[tokio::test]
    async fn test_post_json_plain_text_error_fallback() {
        let mock_server = MockServer::start().await;
        mount_mock(
            &mock_server,
            "POST",
            "/plain-error",
            ResponseTemplate::new(503).set_body_raw("Service unavailable", "text/plain"),
        )
        .await;
        let client = test_client(&mock_server);

        let req_body = serde_json::json!({});
        let result: Result<MockResponse, OpenAIError> =
            post_json(&client, "plain-error", &req_body).await;

        match result {
            Err(OpenAIError::APIError { message, .. }) => {
                assert!(
                    message.contains("Service unavailable"),
                    "Unexpected message: {message}"
                );
            }
            other => panic!("Expected generic APIError, got {:?}", other),
        }
    }
}