1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use pin_project_lite::pin_project;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10use crate::error::Result;
11
12fn null_as_empty_vec<'de, D, T>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
14where
15 D: serde::Deserializer<'de>,
16 T: Deserialize<'de>,
17{
18 Option::<Vec<T>>::deserialize(deserializer).map(|v| v.unwrap_or_default())
19}
20
21fn deserialize_opt_vec<'de, D, T>(deserializer: D) -> std::result::Result<Option<Vec<T>>, D::Error>
23where
24 D: serde::Deserializer<'de>,
25 T: Deserialize<'de>,
26{
27 Ok(Option::<Vec<T>>::deserialize(deserializer).unwrap_or(None))
29}
30
/// JSON request body sent to the `/qai/v1/chat` endpoint by `Client::chat` and
/// `Client::chat_stream`.
///
/// Every `Option` field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatRequest {
    /// Model to run the chat with.
    pub model: String,

    /// Conversation messages sent with the request.
    pub messages: Vec<ChatMessage>,

    /// Tools the model may call, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ChatTool>>,

    /// Tool-selection setting passed through as a string.
    /// NOTE(review): accepted values are defined server-side; confirm against API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,

    /// Arbitrary JSON describing the desired output shape (presumably a JSON Schema; confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<serde_json::Value>,

    /// Streaming flag. `Client::chat` overwrites it with `Some(false)` and
    /// `Client::chat_stream` with `Some(true)`, so callers need not set it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Optional limit on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,

    /// Extra provider-specific options forwarded as raw JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<HashMap<String, serde_json::Value>>,
}
68
/// A single conversation message, as carried in `ChatRequest::messages`.
///
/// The associated constructors (`user`, `assistant`, `system`, `tool_result`,
/// `tool_error`) set `role` to `"user"`, `"assistant"`, `"system"` or `"tool"`.
/// `Option` fields are omitted from serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatMessage {
    /// Speaker role string.
    pub role: String,

    /// Plain-text content of the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Structured content blocks. When deserializing, an absent, `null` or
    /// otherwise unusable value becomes `None` (via `deserialize_opt_vec`).
    #[serde(skip_serializing_if = "Option::is_none", deserialize_with = "deserialize_opt_vec", default)]
    pub content_blocks: Option<Vec<ContentBlock>>,

    /// For `"tool"` messages: the id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// `Some(true)` marks a tool message as an error result (set by `tool_error`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}
92
93impl ChatMessage {
94 pub fn user(content: impl Into<String>) -> Self {
96 Self {
97 role: "user".to_string(),
98 content: Some(content.into()),
99 ..Default::default()
100 }
101 }
102
103 pub fn assistant(content: impl Into<String>) -> Self {
105 Self {
106 role: "assistant".to_string(),
107 content: Some(content.into()),
108 ..Default::default()
109 }
110 }
111
112 pub fn system(content: impl Into<String>) -> Self {
114 Self {
115 role: "system".to_string(),
116 content: Some(content.into()),
117 ..Default::default()
118 }
119 }
120
121 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
123 Self {
124 role: "tool".to_string(),
125 content: Some(content.into()),
126 tool_call_id: Some(tool_call_id.into()),
127 ..Default::default()
128 }
129 }
130
131 pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
133 Self {
134 role: "tool".to_string(),
135 content: Some(content.into()),
136 tool_call_id: Some(tool_call_id.into()),
137 is_error: Some(true),
138 ..Default::default()
139 }
140 }
141}
142
/// One typed piece of message or response content.
///
/// `block_type` (the JSON `type` key) decides which optional fields are meaningful.
/// `ChatResponse` reads `"text"`, `"thinking"` and `"tool_use"` blocks. Every
/// `Option` field is omitted from serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ContentBlock {
    /// Block kind, serialized as `type`.
    #[serde(rename = "type")]
    pub block_type: String,

    /// Text payload; read for `"text"` and `"thinking"` blocks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Presumably the tool call id for `"tool_use"` blocks (cf. `StreamToolUse::id`); confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    /// Presumably the tool name for `"tool_use"` blocks; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Presumably the tool-call arguments for `"tool_use"` blocks; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<HashMap<String, serde_json::Value>>,

    /// NOTE(review): provider-specific signature; its semantics are not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,

    /// Presumably inline file/media data (encoding not visible here; confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<String>,

    /// Presumably the name of an attached file; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_name: Option<String>,

    /// Presumably the MIME type of `data` / the referenced file; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,

    /// Presumably a URI referencing remotely stored file content; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_uri: Option<String>,
}
190
/// A tool definition offered to the model via `ChatRequest::tools`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatTool {
    /// Tool name.
    pub name: String,

    /// Human-readable description of the tool.
    pub description: String,

    /// Parameter description as raw JSON (presumably a JSON Schema; confirm).
    /// Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,

    /// Optional strictness flag, omitted when `None`.
    /// NOTE(review): exact server semantics not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
208
209#[derive(Debug, Clone, Deserialize)]
211pub struct ChatResponse {
212 pub id: String,
214
215 pub model: String,
217
218 #[serde(default, deserialize_with = "null_as_empty_vec")]
220 pub content: Vec<ContentBlock>,
221
222 pub usage: Option<ChatUsage>,
224
225 #[serde(default)]
227 pub stop_reason: String,
228
229 #[serde(default, deserialize_with = "null_as_empty_vec")]
231 pub citations: Vec<Citation>,
232
233 #[serde(skip)]
235 pub cost_ticks: i64,
236
237 #[serde(skip)]
239 pub request_id: String,
240}
241
242impl ChatResponse {
243 pub fn text(&self) -> String {
245 self.content
246 .iter()
247 .filter(|b| b.block_type == "text")
248 .filter_map(|b| b.text.as_deref())
249 .collect::<Vec<_>>()
250 .join("")
251 }
252
253 pub fn thinking(&self) -> String {
255 self.content
256 .iter()
257 .filter(|b| b.block_type == "thinking")
258 .filter_map(|b| b.text.as_deref())
259 .collect::<Vec<_>>()
260 .join("")
261 }
262
263 pub fn tool_calls(&self) -> Vec<&ContentBlock> {
265 self.content
266 .iter()
267 .filter(|b| b.block_type == "tool_use")
268 .collect()
269 }
270}
271
/// A source cited by a chat response.
///
/// Every field falls back to its default (empty string / `0`) when absent.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Citation {
    /// Title of the cited source.
    #[serde(default)]
    pub title: String,

    /// URL of the cited source.
    #[serde(default)]
    pub url: String,

    /// Cited text.
    #[serde(default)]
    pub text: String,

    /// NOTE(review): presumably the citation's position/number; confirm against API docs.
    #[serde(default)]
    pub index: i32,
}
291
/// Token counts and cost reported for a chat request.
///
/// When deserialized directly, all three fields are required. The streaming
/// decoder instead builds it from a `usage` event and uses `0` for missing values.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatUsage {
    /// Input (prompt) token count.
    pub input_tokens: i32,
    /// Output (generated) token count.
    pub output_tokens: i32,
    /// Cost of the request. NOTE(review): unit of a "tick" is not defined here.
    pub cost_ticks: i64,
}
299
/// One decoded event from a streaming chat response.
///
/// `event_type` mirrors the server event's `type` (e.g. `content_delta`,
/// `thinking_delta`, `tool_use`, `tool_use_start`, `tool_use_input_delta`,
/// `tool_use_complete`, `usage`, `error`, `heartbeat`). It is `"done"` for the
/// `[DONE]` sentinel. At most one payload field below is populated, chosen by
/// `event_type`.
#[derive(Debug, Clone)]
pub struct StreamEvent {
    /// Event type string.
    pub event_type: String,

    /// Text delta; set for `content_delta` and `thinking_delta` events.
    pub delta: Option<StreamDelta>,

    /// Complete tool call; set for `tool_use` events.
    pub tool_use: Option<StreamToolUse>,

    /// Start of an incrementally streamed tool call; set for `tool_use_start` events.
    pub tool_use_start: Option<StreamToolUseStart>,

    /// Fragment of a tool call's input JSON; set for `tool_use_input_delta` events.
    pub tool_use_input_delta: Option<StreamToolUseInputDelta>,

    /// Finished tool call; set for `tool_use_complete` events.
    pub tool_use_complete: Option<StreamToolUseComplete>,

    /// Token usage and cost; set for `usage` events.
    pub usage: Option<ChatUsage>,

    /// Error message; set for server `error` events and for payloads that fail to parse.
    pub error: Option<String>,

    /// `true` only for the `[DONE]` sentinel event.
    pub done: bool,
}
339
/// Incremental text carried by `content_delta` / `thinking_delta` stream events.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamDelta {
    /// Newly generated text fragment.
    pub text: String,
}
345
/// A complete tool call delivered in a single `tool_use` stream event.
///
/// Missing fields in the event become empty values.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamToolUse {
    /// Tool call id.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Tool arguments as a JSON object.
    pub input: HashMap<String, serde_json::Value>,
}
353
/// Announces a tool call in a `tool_use_start` stream event.
/// Presumably its input then arrives via `tool_use_input_delta` events with the same `id`; confirm.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamToolUseStart {
    /// Tool call id.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
}
360
/// A fragment of a tool call's input JSON from a `tool_use_input_delta` stream event.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamToolUseInputDelta {
    /// Id of the tool call this fragment belongs to.
    pub id: String,
    /// Raw, possibly incomplete JSON text. Presumably callers concatenate the
    /// fragments for one `id` in order; confirm against the server protocol.
    pub partial_json: String,
}
370
/// A finished tool call from a `tool_use_complete` stream event.
///
/// Missing fields in the event become empty values.
#[derive(Debug, Clone, Deserialize)]
pub struct StreamToolUseComplete {
    /// Tool call id.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Tool arguments as a JSON object.
    pub input: HashMap<String, serde_json::Value>,
}
379
/// Wire shape of one SSE JSON payload.
///
/// This is a flat union of the fields used by every event type. Only `type` is
/// required; everything else defaults to `None` and is picked out according to
/// `event_type` when building a `StreamEvent`.
#[derive(Deserialize)]
struct RawStreamEvent {
    #[serde(rename = "type")]
    event_type: String,
    // `content_delta` / `thinking_delta` payload.
    #[serde(default)]
    delta: Option<StreamDelta>,
    // Tool call id for the `tool_use*` events.
    #[serde(default)]
    id: Option<String>,
    // Tool name for `tool_use`, `tool_use_start`, `tool_use_complete`.
    #[serde(default)]
    name: Option<String>,
    // Tool arguments for `tool_use` and `tool_use_complete`.
    #[serde(default)]
    input: Option<HashMap<String, serde_json::Value>>,
    // Input JSON fragment for `tool_use_input_delta`.
    #[serde(default)]
    partial_json: Option<String>,
    // Usage figures for `usage` events.
    #[serde(default)]
    input_tokens: Option<i32>,
    #[serde(default)]
    output_tokens: Option<i32>,
    #[serde(default)]
    cost_ticks: Option<i64>,
    // Error text for `error` events.
    #[serde(default)]
    message: Option<String>,
}
405
pin_project! {
    // Stream of decoded chat events returned by `Client::chat_stream`.
    // Wraps a boxed, type-erased `StreamEvent` stream.
    pub struct ChatStream {
        // The boxed event stream; polled through the pin projection in `poll_next`.
        #[pin]
        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
    }
}
413
414impl Stream for ChatStream {
415 type Item = StreamEvent;
416
417 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
418 self.project().inner.poll_next(cx)
419 }
420}
421
422impl Client {
423 pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
425 let mut req = req.clone();
426 req.stream = Some(false);
427
428 let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
429 resp.cost_ticks = meta.cost_ticks;
430 resp.request_id = meta.request_id;
431 if resp.model.is_empty() {
432 resp.model = meta.model;
433 }
434 Ok(resp)
435 }
436
437 pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
461 let mut req = req.clone();
462 req.stream = Some(true);
463
464 let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
465
466 let byte_stream = resp.bytes_stream();
467 let event_stream = sse_to_events(byte_stream);
468
469 Ok(ChatStream {
470 inner: Box::pin(event_stream),
471 })
472 }
473}
474
475fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
477where
478 S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
479{
480 let pinned_stream = Box::pin(byte_stream);
482
483 let line_stream = futures_util::stream::unfold(
486 (pinned_stream, Vec::<u8>::new()),
487 |(mut stream, mut buffer)| async move {
488 use futures_util::StreamExt;
489 loop {
490 if let Some(newline_pos) = buffer.iter().position(|&b| b == b'\n') {
492 let mut line_bytes = buffer[..newline_pos].to_vec();
493 buffer = buffer[newline_pos + 1..].to_vec();
494 if line_bytes.last() == Some(&b'\r') {
496 line_bytes.pop();
497 }
498 let line = String::from_utf8_lossy(&line_bytes).into_owned();
499 return Some((line, (stream, buffer)));
500 }
501
502 match stream.next().await {
504 Some(Ok(chunk)) => {
505 buffer.extend_from_slice(&chunk);
506 }
507 Some(Err(_)) | None => {
508 if !buffer.is_empty() {
510 let remaining = String::from_utf8_lossy(&buffer).into_owned();
511 buffer.clear();
512 return Some((remaining, (stream, buffer)));
513 }
514 return None;
515 }
516 }
517 }
518 },
519 );
520
521 let pinned_lines = Box::pin(line_stream);
522 futures_util::stream::unfold(pinned_lines, |mut lines| async move {
523 use futures_util::StreamExt;
524 loop {
525 let line = lines.next().await?;
526
527 if !line.starts_with("data: ") {
528 continue;
529 }
530 let payload = &line["data: ".len()..];
531
532 if payload == "[DONE]" {
533 let ev = StreamEvent {
534 event_type: "done".to_string(),
535 delta: None,
536 tool_use: None,
537 tool_use_start: None,
538 tool_use_input_delta: None,
539 tool_use_complete: None,
540 usage: None,
541 error: None,
542 done: true,
543 };
544 return Some((ev, lines));
545 }
546
547 let raw: RawStreamEvent = match serde_json::from_str(payload) {
548 Ok(r) => r,
549 Err(e) => {
550 let ev = StreamEvent {
551 event_type: "error".to_string(),
552 delta: None,
553 tool_use: None,
554 tool_use_start: None,
555 tool_use_input_delta: None,
556 tool_use_complete: None,
557 usage: None,
558 error: Some(format!("parse SSE: {e}")),
559 done: false,
560 };
561 return Some((ev, lines));
562 }
563 };
564
565 let mut ev = StreamEvent {
566 event_type: raw.event_type.clone(),
567 delta: None,
568 tool_use: None,
569 tool_use_start: None,
570 tool_use_input_delta: None,
571 tool_use_complete: None,
572 usage: None,
573 error: None,
574 done: false,
575 };
576
577 match raw.event_type.as_str() {
578 "content_delta" | "thinking_delta" => {
579 ev.delta = raw.delta;
580 }
581 "tool_use" => {
582 ev.tool_use = Some(StreamToolUse {
585 id: raw.id.unwrap_or_default(),
586 name: raw.name.unwrap_or_default(),
587 input: raw.input.unwrap_or_default(),
588 });
589 }
590 "tool_use_start" => {
591 ev.tool_use_start = Some(StreamToolUseStart {
592 id: raw.id.unwrap_or_default(),
593 name: raw.name.unwrap_or_default(),
594 });
595 }
596 "tool_use_input_delta" => {
597 ev.tool_use_input_delta = Some(StreamToolUseInputDelta {
598 id: raw.id.unwrap_or_default(),
599 partial_json: raw.partial_json.unwrap_or_default(),
600 });
601 }
602 "tool_use_complete" => {
603 ev.tool_use_complete = Some(StreamToolUseComplete {
604 id: raw.id.unwrap_or_default(),
605 name: raw.name.unwrap_or_default(),
606 input: raw.input.unwrap_or_default(),
607 });
608 }
609 "usage" => {
610 ev.usage = Some(ChatUsage {
611 input_tokens: raw.input_tokens.unwrap_or(0),
612 output_tokens: raw.output_tokens.unwrap_or(0),
613 cost_ticks: raw.cost_ticks.unwrap_or(0),
614 });
615 }
616 "error" => {
617 ev.error = raw.message;
618 }
619 "heartbeat" => {}
620 _ => {}
621 }
622
623 return Some((ev, lines));
624 }
625 })
626}