1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use pin_project_lite::pin_project;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10use crate::error::Result;
11
/// Request body sent to the `/qai/v1/chat` endpoint by [`Client::chat`] and
/// [`Client::chat_stream`].
///
/// Optional fields are omitted from the serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatRequest {
    /// Model identifier the request is addressed to.
    pub model: String,

    /// Conversation history, in order.
    pub messages: Vec<ChatMessage>,

    /// Tools the model may call; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ChatTool>>,

    /// Streaming flag. It is overwritten by `Client::chat` (`Some(false)`) and
    /// `Client::chat_stream` (`Some(true)`), so callers normally leave it `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// Sampling temperature (interpreted by the server); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Token limit, presumably on the generated output; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,

    /// Arbitrary extra JSON options, serialized as a `provider_options` object;
    /// omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<HashMap<String, serde_json::Value>>,
}
41
/// A single message in a conversation.
///
/// `role` is a free-form string. The constructors on this type use `"user"`,
/// `"assistant"`, `"system"` and `"tool"`. Optional fields are omitted from the
/// serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatMessage {
    /// Author role of the message (e.g. `"user"`, `"tool"`).
    pub role: String,

    /// Plain-text content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Structured content blocks. Presumably an alternative to `content` for
    /// multi-part messages; NOTE(review): confirm how the server combines the two.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content_blocks: Option<Vec<ContentBlock>>,

    /// For `"tool"` messages: the id of the tool call this message responds to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// Set to `Some(true)` by `ChatMessage::tool_error` to flag a failed tool call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}
65
66impl ChatMessage {
67 pub fn user(content: impl Into<String>) -> Self {
69 Self {
70 role: "user".to_string(),
71 content: Some(content.into()),
72 ..Default::default()
73 }
74 }
75
76 pub fn assistant(content: impl Into<String>) -> Self {
78 Self {
79 role: "assistant".to_string(),
80 content: Some(content.into()),
81 ..Default::default()
82 }
83 }
84
85 pub fn system(content: impl Into<String>) -> Self {
87 Self {
88 role: "system".to_string(),
89 content: Some(content.into()),
90 ..Default::default()
91 }
92 }
93
94 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
96 Self {
97 role: "tool".to_string(),
98 content: Some(content.into()),
99 tool_call_id: Some(tool_call_id.into()),
100 ..Default::default()
101 }
102 }
103
104 pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
106 Self {
107 role: "tool".to_string(),
108 content: Some(content.into()),
109 tool_call_id: Some(tool_call_id.into()),
110 is_error: Some(true),
111 ..Default::default()
112 }
113 }
114}
115
/// One typed block of message content.
///
/// `block_type` (serialized as `"type"`) selects which optional fields are
/// meaningful. The `ChatResponse` helpers recognise `"text"`, `"thinking"` and
/// `"tool_use"`. Other values may exist; confirm against the API docs.
/// Optional fields are omitted from the serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ContentBlock {
    /// Block kind, serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub block_type: String,

    /// Text of `"text"` and `"thinking"` blocks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Identifier; presumably the tool-call id for `"tool_use"` blocks — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    /// Presumably the tool name for `"tool_use"` blocks — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// JSON object of arguments; presumably tool input for `"tool_use"` blocks — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<HashMap<String, serde_json::Value>>,

    /// Opaque signature string. NOTE(review): meaning not visible here (possibly
    /// tied to thinking blocks) — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,

    /// Inline payload data. Its encoding (e.g. base64) is not shown here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<String>,

    /// File name, presumably accompanying `data` for file attachments.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_name: Option<String>,

    /// MIME type, presumably describing `data`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
}
155
/// Definition of a tool the model may call, sent in `ChatRequest::tools`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatTool {
    /// Tool name; presumably echoed back in `"tool_use"` content blocks.
    pub name: String,

    /// Human-readable description of the tool.
    pub description: String,

    /// Parameter description as raw JSON, presumably a JSON Schema object — confirm.
    /// Omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
}
169
170#[derive(Debug, Clone, Deserialize)]
172pub struct ChatResponse {
173 pub id: String,
175
176 pub model: String,
178
179 #[serde(default)]
181 pub content: Vec<ContentBlock>,
182
183 pub usage: Option<ChatUsage>,
185
186 #[serde(default)]
188 pub stop_reason: String,
189
190 #[serde(default)]
192 pub citations: Vec<Citation>,
193
194 #[serde(skip)]
196 pub cost_ticks: i64,
197
198 #[serde(skip)]
200 pub request_id: String,
201}
202
203impl ChatResponse {
204 pub fn text(&self) -> String {
206 self.content
207 .iter()
208 .filter(|b| b.block_type == "text")
209 .filter_map(|b| b.text.as_deref())
210 .collect::<Vec<_>>()
211 .join("")
212 }
213
214 pub fn thinking(&self) -> String {
216 self.content
217 .iter()
218 .filter(|b| b.block_type == "thinking")
219 .filter_map(|b| b.text.as_deref())
220 .collect::<Vec<_>>()
221 .join("")
222 }
223
224 pub fn tool_calls(&self) -> Vec<&ContentBlock> {
226 self.content
227 .iter()
228 .filter(|b| b.block_type == "tool_use")
229 .collect()
230 }
231}
232
/// A citation attached to a `ChatResponse`.
///
/// Every field falls back to its default (empty string / `0`) when missing from
/// the JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Citation {
    /// Title of the cited source.
    #[serde(default)]
    pub title: String,

    /// URL of the cited source.
    #[serde(default)]
    pub url: String,

    /// Cited text.
    #[serde(default)]
    pub text: String,

    /// Citation index. NOTE(review): what it indexes into is not visible here — confirm.
    #[serde(default)]
    pub index: i32,
}
252
253#[derive(Debug, Clone, Deserialize)]
255pub struct ChatUsage {
256 pub input_tokens: i32,
257 pub output_tokens: i32,
258 pub cost_ticks: i64,
259}
260
/// One decoded event from a chat stream.
///
/// `event_type` is the payload's `type` field, except for two synthesized values:
/// `"done"` for the `[DONE]` sentinel and `"error"` for payloads that fail to
/// parse. Which optional field is populated depends on `event_type`.
#[derive(Debug, Clone)]
pub struct StreamEvent {
    /// Event type string.
    pub event_type: String,

    /// Text delta, set for `"content_delta"` and `"thinking_delta"` events.
    pub delta: Option<StreamDelta>,

    /// Tool call (id, name, input), set for `"tool_use"` events.
    pub tool_use: Option<StreamToolUse>,

    /// Usage counters, set for `"usage"` events.
    pub usage: Option<ChatUsage>,

    /// Error message for `"error"` events (server-sent, or a local parse failure).
    pub error: Option<String>,

    /// `true` only for the terminal `[DONE]` event.
    pub done: bool,
}
282
/// Incremental text carried by `"content_delta"` / `"thinking_delta"` stream events.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamDelta {
    /// Text fragment to append; required when a `delta` object is present.
    pub text: String,
}
288
/// Tool call carried by a `"tool_use"` stream event.
///
/// When decoded from the stream, any field missing from the payload is left
/// empty rather than treated as an error.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamToolUse {
    /// Tool-call identifier.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Tool arguments as a JSON object.
    pub input: HashMap<String, serde_json::Value>,
}
296
297#[derive(Deserialize)]
299struct RawStreamEvent {
300 #[serde(rename = "type")]
301 event_type: String,
302 #[serde(default)]
303 delta: Option<StreamDelta>,
304 #[serde(default)]
305 id: Option<String>,
306 #[serde(default)]
307 name: Option<String>,
308 #[serde(default)]
309 input: Option<HashMap<String, serde_json::Value>>,
310 #[serde(default)]
311 input_tokens: Option<i32>,
312 #[serde(default)]
313 output_tokens: Option<i32>,
314 #[serde(default)]
315 cost_ticks: Option<i64>,
316 #[serde(default)]
317 message: Option<String>,
318}
319
pin_project! {
    /// Stream of [`StreamEvent`]s returned by [`Client::chat_stream`].
    ///
    /// Wraps a boxed, type-erased event stream so the public type has a nameable
    /// concrete type.
    pub struct ChatStream {
        // Event pipeline built by `sse_to_events`. `Pin<Box<_>>` is itself `Unpin`,
        // so the `#[pin]` projection is a convenience rather than a requirement.
        #[pin]
        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
    }
}
327
328impl Stream for ChatStream {
329 type Item = StreamEvent;
330
331 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
332 self.project().inner.poll_next(cx)
333 }
334}
335
336impl Client {
337 pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
339 let mut req = req.clone();
340 req.stream = Some(false);
341
342 let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
343 resp.cost_ticks = meta.cost_ticks;
344 resp.request_id = meta.request_id;
345 if resp.model.is_empty() {
346 resp.model = meta.model;
347 }
348 Ok(resp)
349 }
350
351 pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
375 let mut req = req.clone();
376 req.stream = Some(true);
377
378 let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
379
380 let byte_stream = resp.bytes_stream();
381 let event_stream = sse_to_events(byte_stream);
382
383 Ok(ChatStream {
384 inner: Box::pin(event_stream),
385 })
386 }
387}
388
389fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
391where
392 S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
393{
394 let pinned_stream = Box::pin(byte_stream);
396
397 let line_stream = futures_util::stream::unfold(
399 (pinned_stream, String::new()),
400 |(mut stream, mut buffer)| async move {
401 use futures_util::StreamExt;
402 loop {
403 if let Some(newline_pos) = buffer.find('\n') {
405 let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
406 buffer = buffer[newline_pos + 1..].to_string();
407 return Some((line, (stream, buffer)));
408 }
409
410 match stream.next().await {
412 Some(Ok(chunk)) => {
413 buffer.push_str(&String::from_utf8_lossy(&chunk));
414 }
415 Some(Err(_)) | None => {
416 if !buffer.is_empty() {
418 let remaining = std::mem::take(&mut buffer);
419 return Some((remaining, (stream, buffer)));
420 }
421 return None;
422 }
423 }
424 }
425 },
426 );
427
428 let pinned_lines = Box::pin(line_stream);
429 futures_util::stream::unfold(pinned_lines, |mut lines| async move {
430 use futures_util::StreamExt;
431 loop {
432 let line = lines.next().await?;
433
434 if !line.starts_with("data: ") {
435 continue;
436 }
437 let payload = &line["data: ".len()..];
438
439 if payload == "[DONE]" {
440 let ev = StreamEvent {
441 event_type: "done".to_string(),
442 delta: None,
443 tool_use: None,
444 usage: None,
445 error: None,
446 done: true,
447 };
448 return Some((ev, lines));
449 }
450
451 let raw: RawStreamEvent = match serde_json::from_str(payload) {
452 Ok(r) => r,
453 Err(e) => {
454 let ev = StreamEvent {
455 event_type: "error".to_string(),
456 delta: None,
457 tool_use: None,
458 usage: None,
459 error: Some(format!("parse SSE: {e}")),
460 done: false,
461 };
462 return Some((ev, lines));
463 }
464 };
465
466 let mut ev = StreamEvent {
467 event_type: raw.event_type.clone(),
468 delta: None,
469 tool_use: None,
470 usage: None,
471 error: None,
472 done: false,
473 };
474
475 match raw.event_type.as_str() {
476 "content_delta" | "thinking_delta" => {
477 ev.delta = raw.delta;
478 }
479 "tool_use" => {
480 ev.tool_use = Some(StreamToolUse {
481 id: raw.id.unwrap_or_default(),
482 name: raw.name.unwrap_or_default(),
483 input: raw.input.unwrap_or_default(),
484 });
485 }
486 "usage" => {
487 ev.usage = Some(ChatUsage {
488 input_tokens: raw.input_tokens.unwrap_or(0),
489 output_tokens: raw.output_tokens.unwrap_or(0),
490 cost_ticks: raw.cost_ticks.unwrap_or(0),
491 });
492 }
493 "error" => {
494 ev.error = raw.message;
495 }
496 "heartbeat" => {}
497 _ => {}
498 }
499
500 return Some((ev, lines));
501 }
502 })
503}