open_routerer/api/
chat.rs

1#![allow(dead_code)]
2
3use std::pin::Pin;
4
5use crate::{
6    client::{Client, ClientConfig},
7    error::{Error, Result},
8    types::chat::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse},
9};
10use async_stream::try_stream;
11use futures::{Stream, StreamExt, TryStreamExt};
12use reqwest::Client as ReqwestClient;
13use tokio_util::{
14    codec::{FramedRead, LinesCodec},
15    io::StreamReader,
16};
17
18use super::request::RequestPayload;
19
/// Handle for the chat completions endpoint.
///
/// Bundles the HTTP client with the configuration that supplies the base URL
/// and the headers attached to every request.
pub struct ChatApi {
    /// HTTP client used to send the requests.
    pub(crate) http_client: ReqwestClient,
    /// Configuration providing `base_url` and `build_headers()`.
    pub(crate) config: ClientConfig,
}
24
25impl ChatApi {
26    pub fn new(http_client: ReqwestClient, config: &ClientConfig) -> Self {
27        Self {
28            http_client,
29            config: config.clone(),
30        }
31    }
32
33    pub async fn completion(
34        &self,
35        request: ChatCompletionRequest,
36    ) -> Result<ChatCompletionResponse> {
37        let client = self.http_client.clone();
38        let config = self.config.clone();
39        // Build the full URL by joining relative path.
40        let url = config
41            .base_url
42            .join("chat/completions")
43            .map_err(|e| Error::ApiError {
44                code: 400,
45                message: format!("URL join error: {}", e),
46                metadata: None,
47            })?;
48
49        let response = client
50            .post(url)
51            .headers(config.build_headers()?)
52            .json(&request)
53            .send()
54            .await?;
55
56        let chat_response = Client::handle_response(response).await?;
57
58        Ok(chat_response)
59    }
60
61    pub fn completion_stream(
62        &self,
63        request: ChatCompletionRequest,
64    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
65        let config = self.config.clone();
66        let client = self.http_client.clone();
67
68        let stream = try_stream! {
69            // Build the complete URL for the chat completions endpoint.
70            let url = config.base_url.join("chat/completions").map_err(|e| Error::ApiError {
71                code: 400,
72                message: format!("Invalid URL: {}", e),
73                metadata: None,
74            })?;
75
76            // Serialize the request into a JSON value.
77            let mut req_body = serde_jsonc2::to_value(&request).map_err(|e| Error::ApiError {
78                code: 500,
79                message: format!("Request serialization error: {}", e),
80                metadata: None,
81            })?;
82            // Ensure streaming is enabled.
83            req_body["stream"] = serde_jsonc2::Value::Bool(true);
84
85            // Issue the POST request with the appropriate headers and JSON body.
86            // Use error_for_status() to perform status checking without consuming the response twice.
87            let response = client
88                .post(url)
89                .headers(config.build_headers()?)
90                .json(&req_body)
91                .send()
92                .await?
93                .error_for_status()
94                .map_err(|e| {
95                    // Map the reqwest error into our custom Error.
96                    Error::ApiError {
97                        code: e.status().map(|s| s.as_u16()).unwrap_or(500),
98                        message: e.to_string(),
99                        metadata: None,
100                    }
101                })?;
102
103            // Convert the response bytes stream into an asynchronous line stream.
104            // At this point, the response is known to be successful.
105            let byte_stream = response.bytes_stream()
106                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
107            let stream_reader = StreamReader::new(byte_stream);
108            let mut lines = FramedRead::new(stream_reader, LinesCodec::new());
109
110            // Process each SSE line.
111            while let Some(line_result) = lines.next().await {
112                // Map any LinesCodec error into our API error.
113                let line = line_result.map_err(|e| Error::ApiError {
114                    code: 500,
115                    message: format!("LinesCodec error: {}", e),
116                    metadata: None,
117                })?;
118                if line.trim().is_empty() {
119                    continue;
120                }
121                if line.starts_with("data:") {
122                    let data_part = line.trim_start_matches("data:").trim();
123                    if data_part == "[DONE]" {
124                        break;
125                    }
126                    match serde_jsonc2::from_str::<ChatCompletionChunk>(data_part) {
127                        Ok(chunk) => yield chunk,
128                        Err(_err) => continue,
129                    }
130                } else if line.starts_with(":") {
131                    // SSE comment; ignore the line.
132                    continue;
133                }
134            }
135        };
136
137        Ok(Box::pin(stream))
138    }
139}
140
141impl From<RequestPayload> for ChatCompletionRequest {
142    fn from(value: RequestPayload) -> Self {
143        Self {
144            model: value.model,
145            messages: value.messages,
146            ..Default::default()
147        }
148    }
149}