1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use pin_project_lite::pin_project;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10use crate::error::Result;
11
/// Request body for the chat endpoint.
///
/// `model` and `messages` are always serialized; every `Option` field is left
/// out of the serialized output while it is `None`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatRequest {
    /// Model identifier, serialized as-is.
    pub model: String,

    /// Messages that make up the conversation.
    pub messages: Vec<ChatMessage>,

    /// Tool definitions sent along with the request, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ChatTool>>,

    /// Streaming flag. `Client::chat` and `Client::chat_stream` overwrite this
    /// on their own copy of the request, so a caller-supplied value is ignored
    /// by those two methods.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// Optional `temperature` value; serialized only when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Optional `max_tokens` value; serialized only when set.
    /// NOTE(review): presumably a cap on generated tokens — confirm against the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,

    /// Free-form JSON values keyed by name; serialized only when set.
    /// NOTE(review): presumably provider-specific settings — confirm against the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<HashMap<String, serde_json::Value>>,
}
41
/// A single message in a chat conversation.
///
/// All `Option` fields are left out of the serialized output while `None`.
/// The constructors on this type cover the common cases.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatMessage {
    /// Role string. The constructors on this type use `"user"`, `"assistant"`,
    /// `"system"` and `"tool"`.
    pub role: String,

    /// Plain-text content of the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Structured content blocks. None of the constructors on this type set it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_blocks: Option<Vec<ContentBlock>>,

    /// Identifier of the tool call this message answers; set by
    /// `tool_result` and `tool_error`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// Error marker; `tool_error` sets it to `Some(true)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}
65
66impl ChatMessage {
67 pub fn user(content: impl Into<String>) -> Self {
69 Self {
70 role: "user".to_string(),
71 content: Some(content.into()),
72 ..Default::default()
73 }
74 }
75
76 pub fn assistant(content: impl Into<String>) -> Self {
78 Self {
79 role: "assistant".to_string(),
80 content: Some(content.into()),
81 ..Default::default()
82 }
83 }
84
85 pub fn system(content: impl Into<String>) -> Self {
87 Self {
88 role: "system".to_string(),
89 content: Some(content.into()),
90 ..Default::default()
91 }
92 }
93
94 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
96 Self {
97 role: "tool".to_string(),
98 content: Some(content.into()),
99 tool_call_id: Some(tool_call_id.into()),
100 ..Default::default()
101 }
102 }
103
104 pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
106 Self {
107 role: "tool".to_string(),
108 content: Some(content.into()),
109 tool_call_id: Some(tool_call_id.into()),
110 is_error: Some(true),
111 ..Default::default()
112 }
113 }
114}
115
/// One block of structured message or response content.
///
/// The block kind is carried in the JSON field `type`. The `ChatResponse`
/// helpers in this file look for the kinds `"text"`, `"thinking"` and
/// `"tool_use"`. All `Option` fields are left out of the serialized output
/// while `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ContentBlock {
    /// Block kind; serialized and deserialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub block_type: String,

    /// Text payload; read by `ChatResponse::text` and `ChatResponse::thinking`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Optional identifier.
    /// NOTE(review): presumably the tool-call id of a `"tool_use"` block — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    /// Optional name.
    /// NOTE(review): presumably the tool name of a `"tool_use"` block — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Optional JSON object of named values.
    /// NOTE(review): presumably the tool arguments of a `"tool_use"` block — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<HashMap<String, serde_json::Value>>,
}
139
/// Definition of a tool that can be listed in `ChatRequest::tools`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatTool {
    /// Tool name.
    pub name: String,

    /// Tool description; always serialized, even when empty.
    pub description: String,

    /// Arbitrary JSON describing the tool's parameters; left out of the
    /// serialized output while `None`.
    /// NOTE(review): presumably a JSON Schema object — confirm against the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
}
153
154#[derive(Debug, Clone, Deserialize)]
156pub struct ChatResponse {
157 pub id: String,
159
160 pub model: String,
162
163 #[serde(default)]
165 pub content: Vec<ContentBlock>,
166
167 pub usage: Option<ChatUsage>,
169
170 #[serde(default)]
172 pub stop_reason: String,
173
174 #[serde(skip)]
176 pub cost_ticks: i64,
177
178 #[serde(skip)]
180 pub request_id: String,
181}
182
183impl ChatResponse {
184 pub fn text(&self) -> String {
186 self.content
187 .iter()
188 .filter(|b| b.block_type == "text")
189 .filter_map(|b| b.text.as_deref())
190 .collect::<Vec<_>>()
191 .join("")
192 }
193
194 pub fn thinking(&self) -> String {
196 self.content
197 .iter()
198 .filter(|b| b.block_type == "thinking")
199 .filter_map(|b| b.text.as_deref())
200 .collect::<Vec<_>>()
201 .join("")
202 }
203
204 pub fn tool_calls(&self) -> Vec<&ContentBlock> {
206 self.content
207 .iter()
208 .filter(|b| b.block_type == "tool_use")
209 .collect()
210 }
211}
212
/// Usage figures attached to a response or to a `"usage"` stream event.
///
/// When deserialized, all three fields are required.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatUsage {
    /// Input token count.
    pub input_tokens: i32,
    /// Output token count.
    pub output_tokens: i32,
    /// Cost in ticks; the size of a tick is not defined in this file.
    pub cost_ticks: i64,
}
220
/// One event yielded by a [`ChatStream`].
///
/// Which optional field is populated depends on `event_type`.
#[derive(Debug, Clone)]
pub struct StreamEvent {
    /// Event kind: the `type` of the decoded SSE payload, `"done"` for the
    /// `[DONE]` sentinel, or `"error"` when a payload could not be parsed.
    pub event_type: String,

    /// Text delta; filled for `"content_delta"` and `"thinking_delta"` events
    /// when the payload carries one.
    pub delta: Option<StreamDelta>,

    /// Tool invocation; filled for `"tool_use"` events.
    pub tool_use: Option<StreamToolUse>,

    /// Usage figures; filled for `"usage"` events.
    pub usage: Option<ChatUsage>,

    /// Error text, when the event reports a failure.
    pub error: Option<String>,

    /// `true` only on the event produced for the `[DONE]` sentinel.
    pub done: bool,
}
242
/// Incremental text carried by a `"content_delta"` or `"thinking_delta"`
/// stream event.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamDelta {
    /// The text fragment; required when deserialized.
    pub text: String,
}
248
/// Tool invocation carried by a `"tool_use"` stream event.
///
/// When built by the SSE decoder in this file, each field falls back to its
/// empty default if the payload omits it.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamToolUse {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// JSON values keyed by name.
    /// NOTE(review): presumably the arguments for the tool — confirm against the API.
    pub input: HashMap<String, serde_json::Value>,
}
256
257#[derive(Deserialize)]
259struct RawStreamEvent {
260 #[serde(rename = "type")]
261 event_type: String,
262 #[serde(default)]
263 delta: Option<StreamDelta>,
264 #[serde(default)]
265 id: Option<String>,
266 #[serde(default)]
267 name: Option<String>,
268 #[serde(default)]
269 input: Option<HashMap<String, serde_json::Value>>,
270 #[serde(default)]
271 input_tokens: Option<i32>,
272 #[serde(default)]
273 output_tokens: Option<i32>,
274 #[serde(default)]
275 cost_ticks: Option<i64>,
276 #[serde(default)]
277 message: Option<String>,
278}
279
pin_project! {
    /// Stream of [`StreamEvent`]s returned by `Client::chat_stream`.
    pub struct ChatStream {
        // Type-erased event stream; `poll_next` delegates to it.
        #[pin]
        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
    }
}
287
288impl Stream for ChatStream {
289 type Item = StreamEvent;
290
291 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
292 self.project().inner.poll_next(cx)
293 }
294}
295
296impl Client {
297 pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
299 let mut req = req.clone();
300 req.stream = Some(false);
301
302 let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
303 resp.cost_ticks = meta.cost_ticks;
304 resp.request_id = meta.request_id;
305 if resp.model.is_empty() {
306 resp.model = meta.model;
307 }
308 Ok(resp)
309 }
310
311 pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
335 let mut req = req.clone();
336 req.stream = Some(true);
337
338 let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
339
340 let byte_stream = resp.bytes_stream();
341 let event_stream = sse_to_events(byte_stream);
342
343 Ok(ChatStream {
344 inner: Box::pin(event_stream),
345 })
346 }
347}
348
349fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
351where
352 S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
353{
354 let pinned_stream = Box::pin(byte_stream);
356
357 let line_stream = futures_util::stream::unfold(
359 (pinned_stream, String::new()),
360 |(mut stream, mut buffer)| async move {
361 use futures_util::StreamExt;
362 loop {
363 if let Some(newline_pos) = buffer.find('\n') {
365 let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
366 buffer = buffer[newline_pos + 1..].to_string();
367 return Some((line, (stream, buffer)));
368 }
369
370 match stream.next().await {
372 Some(Ok(chunk)) => {
373 buffer.push_str(&String::from_utf8_lossy(&chunk));
374 }
375 Some(Err(_)) | None => {
376 if !buffer.is_empty() {
378 let remaining = std::mem::take(&mut buffer);
379 return Some((remaining, (stream, buffer)));
380 }
381 return None;
382 }
383 }
384 }
385 },
386 );
387
388 let pinned_lines = Box::pin(line_stream);
389 futures_util::stream::unfold(pinned_lines, |mut lines| async move {
390 use futures_util::StreamExt;
391 loop {
392 let line = lines.next().await?;
393
394 if !line.starts_with("data: ") {
395 continue;
396 }
397 let payload = &line["data: ".len()..];
398
399 if payload == "[DONE]" {
400 let ev = StreamEvent {
401 event_type: "done".to_string(),
402 delta: None,
403 tool_use: None,
404 usage: None,
405 error: None,
406 done: true,
407 };
408 return Some((ev, lines));
409 }
410
411 let raw: RawStreamEvent = match serde_json::from_str(payload) {
412 Ok(r) => r,
413 Err(e) => {
414 let ev = StreamEvent {
415 event_type: "error".to_string(),
416 delta: None,
417 tool_use: None,
418 usage: None,
419 error: Some(format!("parse SSE: {e}")),
420 done: false,
421 };
422 return Some((ev, lines));
423 }
424 };
425
426 let mut ev = StreamEvent {
427 event_type: raw.event_type.clone(),
428 delta: None,
429 tool_use: None,
430 usage: None,
431 error: None,
432 done: false,
433 };
434
435 match raw.event_type.as_str() {
436 "content_delta" | "thinking_delta" => {
437 ev.delta = raw.delta;
438 }
439 "tool_use" => {
440 ev.tool_use = Some(StreamToolUse {
441 id: raw.id.unwrap_or_default(),
442 name: raw.name.unwrap_or_default(),
443 input: raw.input.unwrap_or_default(),
444 });
445 }
446 "usage" => {
447 ev.usage = Some(ChatUsage {
448 input_tokens: raw.input_tokens.unwrap_or(0),
449 output_tokens: raw.output_tokens.unwrap_or(0),
450 cost_ticks: raw.cost_ticks.unwrap_or(0),
451 });
452 }
453 "error" => {
454 ev.error = raw.message;
455 }
456 "heartbeat" => {}
457 _ => {}
458 }
459
460 return Some((ev, lines));
461 }
462 })
463}