open_routerer/api/
chat.rs1#![allow(dead_code)]
2
3use std::pin::Pin;
4
5use crate::{
6 client::{Client, ClientConfig},
7 error::{Error, Result},
8 types::chat::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse},
9};
10use async_stream::try_stream;
11use futures::{Stream, StreamExt, TryStreamExt};
12use reqwest::Client as ReqwestClient;
13use tokio_util::{
14 codec::{FramedRead, LinesCodec},
15 io::StreamReader,
16};
17
18use super::request::RequestPayload;
19
20pub struct ChatApi {
21 pub(crate) http_client: ReqwestClient,
22 pub(crate) config: ClientConfig,
23}
24
25impl ChatApi {
26 pub fn new(http_client: ReqwestClient, config: &ClientConfig) -> Self {
27 Self {
28 http_client,
29 config: config.clone(),
30 }
31 }
32
33 pub async fn completion(
34 &self,
35 request: ChatCompletionRequest,
36 ) -> Result<ChatCompletionResponse> {
37 let client = self.http_client.clone();
38 let config = self.config.clone();
39 let url = config
41 .base_url
42 .join("chat/completions")
43 .map_err(|e| Error::ApiError {
44 code: 400,
45 message: format!("URL join error: {}", e),
46 metadata: None,
47 })?;
48
49 let response = client
50 .post(url)
51 .headers(config.build_headers()?)
52 .json(&request)
53 .send()
54 .await?;
55
56 let chat_response = Client::handle_response(response).await?;
57
58 Ok(chat_response)
59 }
60
61 pub fn completion_stream(
62 &self,
63 request: ChatCompletionRequest,
64 ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
65 let config = self.config.clone();
66 let client = self.http_client.clone();
67
68 let stream = try_stream! {
69 let url = config.base_url.join("chat/completions").map_err(|e| Error::ApiError {
71 code: 400,
72 message: format!("Invalid URL: {}", e),
73 metadata: None,
74 })?;
75
76 let mut req_body = serde_jsonc2::to_value(&request).map_err(|e| Error::ApiError {
78 code: 500,
79 message: format!("Request serialization error: {}", e),
80 metadata: None,
81 })?;
82 req_body["stream"] = serde_jsonc2::Value::Bool(true);
84
85 let response = client
88 .post(url)
89 .headers(config.build_headers()?)
90 .json(&req_body)
91 .send()
92 .await?
93 .error_for_status()
94 .map_err(|e| {
95 Error::ApiError {
97 code: e.status().map(|s| s.as_u16()).unwrap_or(500),
98 message: e.to_string(),
99 metadata: None,
100 }
101 })?;
102
103 let byte_stream = response.bytes_stream()
106 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
107 let stream_reader = StreamReader::new(byte_stream);
108 let mut lines = FramedRead::new(stream_reader, LinesCodec::new());
109
110 while let Some(line_result) = lines.next().await {
112 let line = line_result.map_err(|e| Error::ApiError {
114 code: 500,
115 message: format!("LinesCodec error: {}", e),
116 metadata: None,
117 })?;
118 if line.trim().is_empty() {
119 continue;
120 }
121 if line.starts_with("data:") {
122 let data_part = line.trim_start_matches("data:").trim();
123 if data_part == "[DONE]" {
124 break;
125 }
126 match serde_jsonc2::from_str::<ChatCompletionChunk>(data_part) {
127 Ok(chunk) => yield chunk,
128 Err(_err) => continue,
129 }
130 } else if line.starts_with(":") {
131 continue;
133 }
134 }
135 };
136
137 Ok(Box::pin(stream))
138 }
139}
140
141impl From<RequestPayload> for ChatCompletionRequest {
142 fn from(value: RequestPayload) -> Self {
143 Self {
144 model: value.model,
145 messages: value.messages,
146 ..Default::default()
147 }
148 }
149}