//! `quantum_sdk/chat.rs` — chat request/response types, SSE stream events, and
//! the `Client::chat` / `Client::chat_stream` methods.

use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures_util::Stream;
use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize};

use crate::client::Client;
use crate::error::Result;
11
12/// Request body for text generation.
13#[derive(Debug, Clone, Serialize, Default)]
14pub struct ChatRequest {
15    /// Model ID that determines provider routing (e.g. "claude-sonnet-4-6", "grok-4-1-fast-non-reasoning").
16    pub model: String,
17
18    /// Conversation history.
19    pub messages: Vec<ChatMessage>,
20
21    /// Functions the model can call.
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tools: Option<Vec<ChatTool>>,
24
25    /// Enables server-sent event streaming. Set automatically by `chat_stream`.
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub stream: Option<bool>,
28
29    /// Controls randomness (0.0-2.0).
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub temperature: Option<f64>,
32
33    /// Limits the response length.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub max_tokens: Option<i32>,
36
37    /// Provider-specific settings (e.g. Anthropic thinking, xAI search).
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub provider_options: Option<HashMap<String, serde_json::Value>>,
40}
41
42/// A single message in a conversation.
43#[derive(Debug, Clone, Serialize, Deserialize, Default)]
44pub struct ChatMessage {
45    /// One of "system", "user", "assistant", or "tool".
46    pub role: String,
47
48    /// Text content of the message.
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub content: Option<String>,
51
52    /// Structured content for assistant messages with tool calls.
53    /// When present, takes precedence over `content`.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub content_blocks: Option<Vec<ContentBlock>>,
56
57    /// Required when role is "tool" — references the tool_use ID.
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub tool_call_id: Option<String>,
60
61    /// Whether a tool result is an error.
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub is_error: Option<bool>,
64}
65
66impl ChatMessage {
67    /// Creates a user message.
68    pub fn user(content: impl Into<String>) -> Self {
69        Self {
70            role: "user".to_string(),
71            content: Some(content.into()),
72            ..Default::default()
73        }
74    }
75
76    /// Creates an assistant message.
77    pub fn assistant(content: impl Into<String>) -> Self {
78        Self {
79            role: "assistant".to_string(),
80            content: Some(content.into()),
81            ..Default::default()
82        }
83    }
84
85    /// Creates a system message.
86    pub fn system(content: impl Into<String>) -> Self {
87        Self {
88            role: "system".to_string(),
89            content: Some(content.into()),
90            ..Default::default()
91        }
92    }
93
94    /// Creates a tool result message.
95    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
96        Self {
97            role: "tool".to_string(),
98            content: Some(content.into()),
99            tool_call_id: Some(tool_call_id.into()),
100            ..Default::default()
101        }
102    }
103
104    /// Creates a tool error result message.
105    pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
106        Self {
107            role: "tool".to_string(),
108            content: Some(content.into()),
109            tool_call_id: Some(tool_call_id.into()),
110            is_error: Some(true),
111            ..Default::default()
112        }
113    }
114}
115
116/// A single block in the response content array.
117#[derive(Debug, Clone, Serialize, Deserialize, Default)]
118pub struct ContentBlock {
119    /// One of "text", "thinking", or "tool_use".
120    #[serde(rename = "type")]
121    pub block_type: String,
122
123    /// Content for "text" and "thinking" blocks.
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub text: Option<String>,
126
127    /// Tool call identifier for "tool_use" blocks.
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub id: Option<String>,
130
131    /// Function name for "tool_use" blocks.
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub name: Option<String>,
134
135    /// Function arguments for "tool_use" blocks.
136    #[serde(skip_serializing_if = "Option::is_none")]
137    pub input: Option<HashMap<String, serde_json::Value>>,
138
139    /// Gemini thought signature — must be echoed back with tool results.
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub thought_signature: Option<String>,
142
143    /// Base64-encoded data for file/image content blocks.
144    #[serde(skip_serializing_if = "Option::is_none")]
145    pub data: Option<String>,
146
147    /// Filename for file content blocks.
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub file_name: Option<String>,
150
151    /// MIME type for file/image content blocks.
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub mime_type: Option<String>,
154}
155
156/// Defines a function the model can call.
157#[derive(Debug, Clone, Serialize, Default)]
158pub struct ChatTool {
159    /// Function name.
160    pub name: String,
161
162    /// Explains what the function does.
163    pub description: String,
164
165    /// JSON Schema for the function's arguments.
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub parameters: Option<serde_json::Value>,
168}
169
170/// Response from a non-streaming chat request.
171#[derive(Debug, Clone, Deserialize)]
172pub struct ChatResponse {
173    /// Unique request identifier.
174    pub id: String,
175
176    /// Model that generated the response.
177    pub model: String,
178
179    /// List of content blocks (text, thinking, tool_use).
180    #[serde(default)]
181    pub content: Vec<ContentBlock>,
182
183    /// Token counts and cost.
184    pub usage: Option<ChatUsage>,
185
186    /// Why generation stopped ("end_turn", "tool_use", "max_tokens").
187    #[serde(default)]
188    pub stop_reason: String,
189
190    /// Citations from web search (when search is enabled via provider_options).
191    #[serde(default)]
192    pub citations: Vec<Citation>,
193
194    /// Total cost from the X-QAI-Cost-Ticks header.
195    #[serde(skip)]
196    pub cost_ticks: i64,
197
198    /// From the X-QAI-Request-Id header.
199    #[serde(skip)]
200    pub request_id: String,
201}
202
203impl ChatResponse {
204    /// Returns the concatenated text content, ignoring thinking and tool_use blocks.
205    pub fn text(&self) -> String {
206        self.content
207            .iter()
208            .filter(|b| b.block_type == "text")
209            .filter_map(|b| b.text.as_deref())
210            .collect::<Vec<_>>()
211            .join("")
212    }
213
214    /// Returns the concatenated thinking content.
215    pub fn thinking(&self) -> String {
216        self.content
217            .iter()
218            .filter(|b| b.block_type == "thinking")
219            .filter_map(|b| b.text.as_deref())
220            .collect::<Vec<_>>()
221            .join("")
222    }
223
224    /// Returns all tool_use blocks from the response.
225    pub fn tool_calls(&self) -> Vec<&ContentBlock> {
226        self.content
227            .iter()
228            .filter(|b| b.block_type == "tool_use")
229            .collect()
230    }
231}
232
233/// A source reference from web search grounding.
234#[derive(Debug, Clone, Deserialize, Serialize)]
235pub struct Citation {
236    /// Title of the cited source.
237    #[serde(default)]
238    pub title: String,
239
240    /// URL of the cited source.
241    #[serde(default)]
242    pub url: String,
243
244    /// Relevant text snippet from the source.
245    #[serde(default)]
246    pub text: String,
247
248    /// Position in the response.
249    #[serde(default)]
250    pub index: i32,
251}
252
253/// Token counts and cost for a chat response.
254#[derive(Debug, Clone, Deserialize)]
255pub struct ChatUsage {
256    pub input_tokens: i32,
257    pub output_tokens: i32,
258    pub cost_ticks: i64,
259}
260
261/// A single event from an SSE chat stream.
262#[derive(Debug, Clone)]
263pub struct StreamEvent {
264    /// Event type: "content_delta", "thinking_delta", "tool_use", "usage", "heartbeat", "error", "done".
265    pub event_type: String,
266
267    /// Incremental text for content_delta and thinking_delta events.
268    pub delta: Option<StreamDelta>,
269
270    /// Populated for tool_use events.
271    pub tool_use: Option<StreamToolUse>,
272
273    /// Populated for usage events.
274    pub usage: Option<ChatUsage>,
275
276    /// Populated for error events.
277    pub error: Option<String>,
278
279    /// True when the stream is complete.
280    pub done: bool,
281}
282
283/// Incremental text in a streaming event.
284#[derive(Debug, Clone, Deserialize)]
285pub struct StreamDelta {
286    pub text: String,
287}
288
289/// A tool call from a streaming event.
290#[derive(Debug, Clone, Deserialize)]
291pub struct StreamToolUse {
292    pub id: String,
293    pub name: String,
294    pub input: HashMap<String, serde_json::Value>,
295}
296
297/// Raw JSON from the SSE stream before parsing into typed fields.
298#[derive(Deserialize)]
299struct RawStreamEvent {
300    #[serde(rename = "type")]
301    event_type: String,
302    #[serde(default)]
303    delta: Option<StreamDelta>,
304    #[serde(default)]
305    id: Option<String>,
306    #[serde(default)]
307    name: Option<String>,
308    #[serde(default)]
309    input: Option<HashMap<String, serde_json::Value>>,
310    #[serde(default)]
311    input_tokens: Option<i32>,
312    #[serde(default)]
313    output_tokens: Option<i32>,
314    #[serde(default)]
315    cost_ticks: Option<i64>,
316    #[serde(default)]
317    message: Option<String>,
318}
319
320pin_project! {
321    /// An async stream of [`StreamEvent`]s from an SSE chat response.
322    pub struct ChatStream {
323        #[pin]
324        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
325    }
326}
327
328impl Stream for ChatStream {
329    type Item = StreamEvent;
330
331    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
332        self.project().inner.poll_next(cx)
333    }
334}
335
336impl Client {
337    /// Sends a non-streaming text generation request.
338    pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
339        let mut req = req.clone();
340        req.stream = Some(false);
341
342        let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
343        resp.cost_ticks = meta.cost_ticks;
344        resp.request_id = meta.request_id;
345        if resp.model.is_empty() {
346            resp.model = meta.model;
347        }
348        Ok(resp)
349    }
350
351    /// Sends a streaming text generation request and returns an async stream of events.
352    ///
353    /// # Example
354    ///
355    /// ```no_run
356    /// use futures_util::StreamExt;
357    ///
358    /// # async fn example() -> quantum_sdk::Result<()> {
359    /// let client = quantum_sdk::Client::new("key");
360    /// let req = quantum_sdk::ChatRequest {
361    ///     model: "claude-sonnet-4-6".into(),
362    ///     messages: vec![quantum_sdk::ChatMessage::user("Hello!")],
363    ///     ..Default::default()
364    /// };
365    /// let mut stream = client.chat_stream(&req).await?;
366    /// while let Some(ev) = stream.next().await {
367    ///     if let Some(delta) = &ev.delta {
368    ///         print!("{}", delta.text);
369    ///     }
370    /// }
371    /// # Ok(())
372    /// # }
373    /// ```
374    pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
375        let mut req = req.clone();
376        req.stream = Some(true);
377
378        let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
379
380        let byte_stream = resp.bytes_stream();
381        let event_stream = sse_to_events(byte_stream);
382
383        Ok(ChatStream {
384            inner: Box::pin(event_stream),
385        })
386    }
387}
388
389/// Converts a byte stream into a stream of parsed [`StreamEvent`]s.
390fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
391where
392    S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
393{
394    // Pin the byte stream so we can poll it inside unfold.
395    let pinned_stream = Box::pin(byte_stream);
396
397    // We accumulate bytes into lines, then parse SSE "data: " lines.
398    let line_stream = futures_util::stream::unfold(
399        (pinned_stream, String::new()),
400        |(mut stream, mut buffer)| async move {
401            use futures_util::StreamExt;
402            loop {
403                // Check if we have a complete line in the buffer.
404                if let Some(newline_pos) = buffer.find('\n') {
405                    let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
406                    buffer = buffer[newline_pos + 1..].to_string();
407                    return Some((line, (stream, buffer)));
408                }
409
410                // Read more data.
411                match stream.next().await {
412                    Some(Ok(chunk)) => {
413                        buffer.push_str(&String::from_utf8_lossy(&chunk));
414                    }
415                    Some(Err(_)) | None => {
416                        // Stream ended. Emit remaining buffer if non-empty.
417                        if !buffer.is_empty() {
418                            let remaining = std::mem::take(&mut buffer);
419                            return Some((remaining, (stream, buffer)));
420                        }
421                        return None;
422                    }
423                }
424            }
425        },
426    );
427
428    let pinned_lines = Box::pin(line_stream);
429    futures_util::stream::unfold(pinned_lines, |mut lines| async move {
430        use futures_util::StreamExt;
431        loop {
432            let line = lines.next().await?;
433
434            if !line.starts_with("data: ") {
435                continue;
436            }
437            let payload = &line["data: ".len()..];
438
439            if payload == "[DONE]" {
440                let ev = StreamEvent {
441                    event_type: "done".to_string(),
442                    delta: None,
443                    tool_use: None,
444                    usage: None,
445                    error: None,
446                    done: true,
447                };
448                return Some((ev, lines));
449            }
450
451            let raw: RawStreamEvent = match serde_json::from_str(payload) {
452                Ok(r) => r,
453                Err(e) => {
454                    let ev = StreamEvent {
455                        event_type: "error".to_string(),
456                        delta: None,
457                        tool_use: None,
458                        usage: None,
459                        error: Some(format!("parse SSE: {e}")),
460                        done: false,
461                    };
462                    return Some((ev, lines));
463                }
464            };
465
466            let mut ev = StreamEvent {
467                event_type: raw.event_type.clone(),
468                delta: None,
469                tool_use: None,
470                usage: None,
471                error: None,
472                done: false,
473            };
474
475            match raw.event_type.as_str() {
476                "content_delta" | "thinking_delta" => {
477                    ev.delta = raw.delta;
478                }
479                "tool_use" => {
480                    ev.tool_use = Some(StreamToolUse {
481                        id: raw.id.unwrap_or_default(),
482                        name: raw.name.unwrap_or_default(),
483                        input: raw.input.unwrap_or_default(),
484                    });
485                }
486                "usage" => {
487                    ev.usage = Some(ChatUsage {
488                        input_tokens: raw.input_tokens.unwrap_or(0),
489                        output_tokens: raw.output_tokens.unwrap_or(0),
490                        cost_ticks: raw.cost_ticks.unwrap_or(0),
491                    });
492                }
493                "error" => {
494                    ev.error = raw.message;
495                }
496                "heartbeat" => {}
497                _ => {}
498            }
499
500            return Some((ev, lines));
501        }
502    })
503}