1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use pin_project_lite::pin_project;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10use crate::error::Result;
11
/// Request body sent to the chat endpoint.
///
/// Every `Option` field is left out of the serialized output when it is
/// `None`, so only explicitly set options reach the wire.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatRequest {
    /// Identifier of the model the request is addressed to.
    pub model: String,

    /// Conversation messages, in the order they are serialized.
    pub messages: Vec<ChatMessage>,

    /// Tool definitions made available for this request, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ChatTool>>,

    /// Streaming flag. `Client::chat` and `Client::chat_stream` overwrite it
    /// (with `Some(false)` and `Some(true)` respectively) on their own copy
    /// of the request, so a caller-supplied value is ignored by those methods.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// Temperature setting forwarded as-is; no range check happens here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Token limit forwarded as-is; presumably caps generated tokens
    /// (TODO confirm against the server contract).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,

    /// Free-form options keyed by name and passed through as raw JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<HashMap<String, serde_json::Value>>,
}
41
/// A single conversation message.
///
/// Use the constructors on this type (`user`, `assistant`, `system`,
/// `tool_result`, `tool_error`) for the common shapes. Fields that are
/// `None` are skipped when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatMessage {
    /// Role of the author; the constructors use `"user"`, `"assistant"`,
    /// `"system"` and `"tool"`.
    pub role: String,

    /// Plain-text body of the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Structured body expressed as content blocks. None of the constructors
    /// set this; it has to be filled in by hand.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_blocks: Option<Vec<ContentBlock>>,

    /// Identifier of the tool call this message responds to (set by
    /// `tool_result` / `tool_error`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// Marks a tool message as a failure (set to `Some(true)` by `tool_error`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}
65
66impl ChatMessage {
67 pub fn user(content: impl Into<String>) -> Self {
69 Self {
70 role: "user".to_string(),
71 content: Some(content.into()),
72 ..Default::default()
73 }
74 }
75
76 pub fn assistant(content: impl Into<String>) -> Self {
78 Self {
79 role: "assistant".to_string(),
80 content: Some(content.into()),
81 ..Default::default()
82 }
83 }
84
85 pub fn system(content: impl Into<String>) -> Self {
87 Self {
88 role: "system".to_string(),
89 content: Some(content.into()),
90 ..Default::default()
91 }
92 }
93
94 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
96 Self {
97 role: "tool".to_string(),
98 content: Some(content.into()),
99 tool_call_id: Some(tool_call_id.into()),
100 ..Default::default()
101 }
102 }
103
104 pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
106 Self {
107 role: "tool".to_string(),
108 content: Some(content.into()),
109 tool_call_id: Some(tool_call_id.into()),
110 is_error: Some(true),
111 ..Default::default()
112 }
113 }
114}
115
/// One typed block of message or response content.
///
/// Which optional fields are populated depends on `block_type`; absent
/// fields are skipped when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ContentBlock {
    /// Block discriminator, serialized as `type`. `ChatResponse` helpers
    /// recognize `"text"`, `"thinking"` and `"tool_use"`.
    #[serde(rename = "type")]
    pub block_type: String,

    /// Text payload; read by `ChatResponse::text` and `ChatResponse::thinking`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Block identifier — presumably the tool-call id on `"tool_use"` blocks
    /// (TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    /// Name associated with the block — presumably the tool name on
    /// `"tool_use"` blocks (TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Arbitrary JSON arguments keyed by parameter name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<HashMap<String, serde_json::Value>>,

    /// Opaque signature string carried with the block; this module only
    /// round-trips it and never inspects it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
144
/// Definition of a tool offered to the model through `ChatRequest::tools`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatTool {
    /// Name of the tool.
    pub name: String,

    /// Human-readable description of the tool.
    pub description: String,

    /// Parameter description as a raw JSON value (presumably a JSON Schema —
    /// TODO confirm); omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
}
158
159#[derive(Debug, Clone, Deserialize)]
161pub struct ChatResponse {
162 pub id: String,
164
165 pub model: String,
167
168 #[serde(default)]
170 pub content: Vec<ContentBlock>,
171
172 pub usage: Option<ChatUsage>,
174
175 #[serde(default)]
177 pub stop_reason: String,
178
179 #[serde(default)]
181 pub citations: Vec<Citation>,
182
183 #[serde(skip)]
185 pub cost_ticks: i64,
186
187 #[serde(skip)]
189 pub request_id: String,
190}
191
192impl ChatResponse {
193 pub fn text(&self) -> String {
195 self.content
196 .iter()
197 .filter(|b| b.block_type == "text")
198 .filter_map(|b| b.text.as_deref())
199 .collect::<Vec<_>>()
200 .join("")
201 }
202
203 pub fn thinking(&self) -> String {
205 self.content
206 .iter()
207 .filter(|b| b.block_type == "thinking")
208 .filter_map(|b| b.text.as_deref())
209 .collect::<Vec<_>>()
210 .join("")
211 }
212
213 pub fn tool_calls(&self) -> Vec<&ContentBlock> {
215 self.content
216 .iter()
217 .filter(|b| b.block_type == "tool_use")
218 .collect()
219 }
220}
221
/// A citation attached to a chat response.
///
/// Every field falls back to its type's default (empty string / `0`) when
/// missing from the input.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Citation {
    /// Title of the cited source.
    #[serde(default)]
    pub title: String,

    /// URL of the cited source.
    #[serde(default)]
    pub url: String,

    /// Text associated with the citation.
    #[serde(default)]
    pub text: String,

    /// Numeric index of the citation; what it indexes into is not visible
    /// from this file (TODO confirm).
    #[serde(default)]
    pub index: i32,
}
241
242#[derive(Debug, Clone, Deserialize)]
244pub struct ChatUsage {
245 pub input_tokens: i32,
246 pub output_tokens: i32,
247 pub cost_ticks: i64,
248}
249
/// One event yielded by a [`ChatStream`].
///
/// `event_type` tells which of the optional payload fields, if any, is set.
#[derive(Debug, Clone)]
pub struct StreamEvent {
    /// Event kind: the `type` field of the server payload, or `"done"` /
    /// `"error"` for events synthesized while decoding the stream.
    pub event_type: String,

    /// Text fragment; filled for `"content_delta"` and `"thinking_delta"`.
    pub delta: Option<StreamDelta>,

    /// Tool invocation; filled for `"tool_use"`.
    pub tool_use: Option<StreamToolUse>,

    /// Usage counters; filled for `"usage"`.
    pub usage: Option<ChatUsage>,

    /// Error description; set on `"error"` events when one is available.
    pub error: Option<String>,

    /// `true` only on the event produced for the `[DONE]` sentinel.
    pub done: bool,
}
271
/// Incremental text carried by a delta event.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamDelta {
    /// The text fragment (required in the payload).
    pub text: String,
}
277
/// Tool invocation reported by a `"tool_use"` stream event.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamToolUse {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// Arguments as raw JSON values keyed by name.
    pub input: HashMap<String, serde_json::Value>,
}
285
/// Wire shape of a single JSON payload in the event stream.
///
/// The payload is flat: the fields of every event kind live side by side and
/// all but `type` are optional. `sse_to_events` picks the ones relevant to
/// `event_type` and folds them into a [`StreamEvent`].
#[derive(Deserialize)]
struct RawStreamEvent {
    // Event discriminator (`type` on the wire); the only required field.
    #[serde(rename = "type")]
    event_type: String,
    // Read for "content_delta" / "thinking_delta".
    #[serde(default)]
    delta: Option<StreamDelta>,
    // `id`, `name` and `input` are read for "tool_use".
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    input: Option<HashMap<String, serde_json::Value>>,
    // The three counters below are read for "usage".
    #[serde(default)]
    input_tokens: Option<i32>,
    #[serde(default)]
    output_tokens: Option<i32>,
    #[serde(default)]
    cost_ticks: Option<i64>,
    // Read for "error".
    #[serde(default)]
    message: Option<String>,
}
308
pin_project! {
    /// Stream of [`StreamEvent`]s returned by `Client::chat_stream`.
    pub struct ChatStream {
        // Boxed, type-erased event stream; structurally pinned so that
        // `poll_next` can forward to it through the projection.
        #[pin]
        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
    }
}
316
317impl Stream for ChatStream {
318 type Item = StreamEvent;
319
320 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
321 self.project().inner.poll_next(cx)
322 }
323}
324
325impl Client {
326 pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
328 let mut req = req.clone();
329 req.stream = Some(false);
330
331 let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
332 resp.cost_ticks = meta.cost_ticks;
333 resp.request_id = meta.request_id;
334 if resp.model.is_empty() {
335 resp.model = meta.model;
336 }
337 Ok(resp)
338 }
339
340 pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
364 let mut req = req.clone();
365 req.stream = Some(true);
366
367 let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
368
369 let byte_stream = resp.bytes_stream();
370 let event_stream = sse_to_events(byte_stream);
371
372 Ok(ChatStream {
373 inner: Box::pin(event_stream),
374 })
375 }
376}
377
378fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
380where
381 S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
382{
383 let pinned_stream = Box::pin(byte_stream);
385
386 let line_stream = futures_util::stream::unfold(
388 (pinned_stream, String::new()),
389 |(mut stream, mut buffer)| async move {
390 use futures_util::StreamExt;
391 loop {
392 if let Some(newline_pos) = buffer.find('\n') {
394 let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
395 buffer = buffer[newline_pos + 1..].to_string();
396 return Some((line, (stream, buffer)));
397 }
398
399 match stream.next().await {
401 Some(Ok(chunk)) => {
402 buffer.push_str(&String::from_utf8_lossy(&chunk));
403 }
404 Some(Err(_)) | None => {
405 if !buffer.is_empty() {
407 let remaining = std::mem::take(&mut buffer);
408 return Some((remaining, (stream, buffer)));
409 }
410 return None;
411 }
412 }
413 }
414 },
415 );
416
417 let pinned_lines = Box::pin(line_stream);
418 futures_util::stream::unfold(pinned_lines, |mut lines| async move {
419 use futures_util::StreamExt;
420 loop {
421 let line = lines.next().await?;
422
423 if !line.starts_with("data: ") {
424 continue;
425 }
426 let payload = &line["data: ".len()..];
427
428 if payload == "[DONE]" {
429 let ev = StreamEvent {
430 event_type: "done".to_string(),
431 delta: None,
432 tool_use: None,
433 usage: None,
434 error: None,
435 done: true,
436 };
437 return Some((ev, lines));
438 }
439
440 let raw: RawStreamEvent = match serde_json::from_str(payload) {
441 Ok(r) => r,
442 Err(e) => {
443 let ev = StreamEvent {
444 event_type: "error".to_string(),
445 delta: None,
446 tool_use: None,
447 usage: None,
448 error: Some(format!("parse SSE: {e}")),
449 done: false,
450 };
451 return Some((ev, lines));
452 }
453 };
454
455 let mut ev = StreamEvent {
456 event_type: raw.event_type.clone(),
457 delta: None,
458 tool_use: None,
459 usage: None,
460 error: None,
461 done: false,
462 };
463
464 match raw.event_type.as_str() {
465 "content_delta" | "thinking_delta" => {
466 ev.delta = raw.delta;
467 }
468 "tool_use" => {
469 ev.tool_use = Some(StreamToolUse {
470 id: raw.id.unwrap_or_default(),
471 name: raw.name.unwrap_or_default(),
472 input: raw.input.unwrap_or_default(),
473 });
474 }
475 "usage" => {
476 ev.usage = Some(ChatUsage {
477 input_tokens: raw.input_tokens.unwrap_or(0),
478 output_tokens: raw.output_tokens.unwrap_or(0),
479 cost_ticks: raw.cost_ticks.unwrap_or(0),
480 });
481 }
482 "error" => {
483 ev.error = raw.message;
484 }
485 "heartbeat" => {}
486 _ => {}
487 }
488
489 return Some((ev, lines));
490 }
491 })
492}