use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::Stream;
use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize};
use crate::client::Client;
use crate::error::Result;
fn null_as_empty_vec<'de, D, T>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
where
D: serde::Deserializer<'de>,
T: Deserialize<'de>,
{
Option::<Vec<T>>::deserialize(deserializer).map(|v| v.unwrap_or_default())
}
fn deserialize_opt_vec<'de, D, T>(deserializer: D) -> std::result::Result<Option<Vec<T>>, D::Error>
where
D: serde::Deserializer<'de>,
T: Deserialize<'de>,
{
Ok(Option::<Vec<T>>::deserialize(deserializer).unwrap_or(None))
}
/// Request body sent to the chat endpoint.
///
/// Every optional field that is `None` is left out of the serialized JSON.
#[derive(Clone, Debug, Default, Serialize)]
pub struct ChatRequest {
    /// Model identifier to run the request against.
    pub model: String,
    /// Conversation messages, in order.
    pub messages: Vec<ChatMessage>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<ChatTool>>,
    /// Tool-selection setting, sent as a plain string (accepted values are server-defined).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Arbitrary JSON describing the desired output shape (presumably a JSON Schema — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<serde_json::Value>,
    /// Streaming flag. `Client::chat` sends `false` and `Client::chat_stream` sends `true`,
    /// each on its own clone of the request, whatever the caller set here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Token limit for the generated output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    /// Extra provider-specific options, serialized as a JSON object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<HashMap<String, serde_json::Value>>,
}
/// One message in a chat conversation.
///
/// The constructors on `ChatMessage` use the roles `"user"`, `"assistant"`, `"system"` and
/// `"tool"`. Optional fields that are `None` are omitted when serialized.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ChatMessage {
    /// Author role of the message.
    pub role: String,
    /// Plain-text content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Structured content blocks. A missing key or `null` yields `None`; see
    /// `deserialize_opt_vec` for how invalid values are handled.
    #[serde(
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_opt_vec",
        default
    )]
    pub content_blocks: Option<Vec<ContentBlock>>,
    /// Id of the tool call this message answers (set by `tool_result` / `tool_error`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// `Some(true)` marks a failed tool result (set by `tool_error`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}
impl ChatMessage {
    /// Shared base for the constructors: a message with `role` and plain-text `content`,
    /// every other field left at its default.
    fn with_role_and_content(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_string(),
            content: Some(content.into()),
            ..Self::default()
        }
    }

    /// Creates a `"user"` message with plain-text content.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role_and_content("user", content)
    }

    /// Creates an `"assistant"` message with plain-text content.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role_and_content("assistant", content)
    }

    /// Creates a `"system"` message with plain-text content.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role_and_content("system", content)
    }

    /// Creates a `"tool"` message carrying the result of the tool call `tool_call_id`.
    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            tool_call_id: Some(tool_call_id.into()),
            ..Self::with_role_and_content("tool", content)
        }
    }

    /// Like [`ChatMessage::tool_result`], but flags the result as an error
    /// (`is_error == Some(true)`).
    pub fn tool_error(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            is_error: Some(true),
            ..Self::tool_result(tool_call_id, content)
        }
    }
}
/// A typed piece of message or response content.
///
/// `ChatResponse` helpers select blocks by `block_type`: `"text"`, `"thinking"` and
/// `"tool_use"`. Optional fields that are `None` are omitted when serialized.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ContentBlock {
    /// Block kind (JSON key `type`).
    #[serde(rename = "type")]
    pub block_type: String,
    /// Text of `"text"` and `"thinking"` blocks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Block identifier (presumably the tool-call id on `"tool_use"` blocks — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Name field (presumably the tool name on `"tool_use"` blocks — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// JSON object of inputs (presumably tool-call arguments — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input: Option<HashMap<String, serde_json::Value>>,
    /// Opaque provider-defined signature string; NOTE(review): meaning not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
    /// Inline data payload; NOTE(review): encoding (e.g. base64) not visible here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<String>,
    /// File name associated with `data`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_name: Option<String>,
    /// MIME type associated with `data`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
}
/// A tool definition sent with a `ChatRequest`.
#[derive(Clone, Debug, Default, Serialize)]
pub struct ChatTool {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Parameter description as raw JSON (presumably a JSON Schema — confirm); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
    /// Strictness flag forwarded to the server; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// Body of a non-streaming chat response.
///
/// `cost_ticks` and `request_id` are never read from the body (`#[serde(skip)]`);
/// `Client::chat` fills them in from the response metadata.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatResponse {
    /// Response identifier.
    pub id: String,
    /// Model that produced the response. Defaults to empty when the body omits the key,
    /// so that `Client::chat` can fall back to the model reported in the response metadata
    /// rather than failing deserialization.
    #[serde(default)]
    pub model: String,
    /// Content blocks in server order; a missing key or `null` gives an empty list.
    #[serde(default, deserialize_with = "null_as_empty_vec")]
    pub content: Vec<ContentBlock>,
    /// Usage report, if the server sent one.
    pub usage: Option<ChatUsage>,
    /// Stop reason; empty when absent.
    #[serde(default)]
    pub stop_reason: String,
    /// Citations; a missing key or `null` gives an empty list.
    #[serde(default, deserialize_with = "null_as_empty_vec")]
    pub citations: Vec<Citation>,
    /// Cost taken from the response metadata (not part of the body).
    #[serde(skip)]
    pub cost_ticks: i64,
    /// Request id taken from the response metadata (not part of the body).
    #[serde(skip)]
    pub request_id: String,
}
impl ChatResponse {
    /// Concatenates, in order, the `text` of every `"text"` content block.
    ///
    /// Returns an empty string when there are no such blocks.
    pub fn text(&self) -> String {
        self.concat_text_of("text")
    }

    /// Concatenates, in order, the `text` of every `"thinking"` content block.
    ///
    /// Returns an empty string when there are no such blocks.
    pub fn thinking(&self) -> String {
        self.concat_text_of("thinking")
    }

    /// Returns every `"tool_use"` content block, in order.
    pub fn tool_calls(&self) -> Vec<&ContentBlock> {
        self.content
            .iter()
            .filter(|block| block.block_type == "tool_use")
            .collect()
    }

    /// Joins the `text` of all blocks whose `block_type` equals `block_type`, skipping
    /// blocks without text. Collects straight into a `String` (no intermediate `Vec`).
    fn concat_text_of(&self, block_type: &str) -> String {
        self.content
            .iter()
            .filter(|block| block.block_type == block_type)
            .filter_map(|block| block.text.as_deref())
            .collect()
    }
}
/// A citation attached to a chat response.
///
/// Each field falls back to its default (empty string or `0`) when absent from the payload.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Citation {
    /// Title of the cited source.
    #[serde(default)]
    pub title: String,
    /// URL of the cited source.
    #[serde(default)]
    pub url: String,
    /// Cited text.
    #[serde(default)]
    pub text: String,
    /// Citation index (presumably its number/position in the response — confirm).
    #[serde(default)]
    pub index: i32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ChatUsage {
pub input_tokens: i32,
pub output_tokens: i32,
pub cost_ticks: i64,
}
/// One parsed event from a streaming chat response.
///
/// `event_type` carries the wire `type` (or `"done"` / `"error"` for synthesized events).
/// At most one payload field is filled in, according to that type.
#[derive(Clone, Debug)]
pub struct StreamEvent {
    /// Event kind.
    pub event_type: String,
    /// Text delta of `content_delta` and `thinking_delta` events.
    pub delta: Option<StreamDelta>,
    /// Payload of a `tool_use` event.
    pub tool_use: Option<StreamToolUse>,
    /// Payload of a `tool_use_start` event.
    pub tool_use_start: Option<StreamToolUseStart>,
    /// Payload of a `tool_use_input_delta` event.
    pub tool_use_input_delta: Option<StreamToolUseInputDelta>,
    /// Payload of a `tool_use_complete` event.
    pub tool_use_complete: Option<StreamToolUseComplete>,
    /// Payload of a `usage` event (missing counters are `0`).
    pub usage: Option<ChatUsage>,
    /// Error text of `error` events (may be `None` if the server sent no message).
    pub error: Option<String>,
    /// `true` only for the event synthesized from the `[DONE]` sentinel.
    pub done: bool,
}
/// Incremental text carried by `content_delta` / `thinking_delta` stream events.
#[derive(Clone, Debug, Deserialize)]
pub struct StreamDelta {
    /// Text fragment of this delta.
    pub text: String,
}
/// Payload of a `tool_use` stream event.
#[derive(Clone, Debug, Deserialize)]
pub struct StreamToolUse {
    /// Tool-call id (empty if the server omitted it).
    pub id: String,
    /// Tool name (empty if the server omitted it).
    pub name: String,
    /// Tool input object (empty if the server omitted it).
    pub input: HashMap<String, serde_json::Value>,
}
/// Payload of a `tool_use_start` stream event.
#[derive(Clone, Debug, Deserialize)]
pub struct StreamToolUseStart {
    /// Tool-call id (empty if the server omitted it).
    pub id: String,
    /// Tool name (empty if the server omitted it).
    pub name: String,
}
/// Payload of a `tool_use_input_delta` stream event.
#[derive(Clone, Debug, Deserialize)]
pub struct StreamToolUseInputDelta {
    /// Tool-call id (empty if the server omitted it).
    pub id: String,
    /// Fragment of the tool input JSON (presumably concatenated in arrival order — confirm).
    pub partial_json: String,
}
/// Payload of a `tool_use_complete` stream event.
#[derive(Clone, Debug, Deserialize)]
pub struct StreamToolUseComplete {
    /// Tool-call id (empty if the server omitted it).
    pub id: String,
    /// Tool name (empty if the server omitted it).
    pub name: String,
    /// Tool input object (empty if the server omitted it).
    pub input: HashMap<String, serde_json::Value>,
}
/// Wire shape of one SSE `data` JSON payload, before conversion into a `StreamEvent`.
///
/// The fields are a union over every event type; which ones are present depends on
/// `type`. Only `type` must be present.
#[derive(Deserialize)]
struct RawStreamEvent {
    /// Event discriminator (JSON key `type`).
    #[serde(rename = "type")]
    event_type: String,
    /// Text delta (`content_delta` / `thinking_delta`).
    #[serde(default)]
    delta: Option<StreamDelta>,
    /// Tool-call id (`tool_use*` events).
    #[serde(default)]
    id: Option<String>,
    /// Tool name (`tool_use`, `tool_use_start`, `tool_use_complete`).
    #[serde(default)]
    name: Option<String>,
    /// Tool input object (`tool_use`, `tool_use_complete`).
    #[serde(default)]
    input: Option<HashMap<String, serde_json::Value>>,
    /// Input JSON fragment (`tool_use_input_delta`).
    #[serde(default)]
    partial_json: Option<String>,
    /// Input token count (`usage`).
    #[serde(default)]
    input_tokens: Option<i32>,
    /// Output token count (`usage`).
    #[serde(default)]
    output_tokens: Option<i32>,
    /// Cost (`usage`).
    #[serde(default)]
    cost_ticks: Option<i64>,
    /// Error text (`error`).
    #[serde(default)]
    message: Option<String>,
}
pin_project! {
    /// Stream of parsed chat events, as returned by `Client::chat_stream`.
    pub struct ChatStream {
        // Type-erased, heap-pinned event stream built from the SSE response body.
        #[pin]
        inner: Pin<Box<dyn Stream<Item = StreamEvent> + Send>>,
    }
}
impl Stream for ChatStream {
    type Item = StreamEvent;

    /// Delegates polling to the boxed inner event stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        projected.inner.poll_next(cx)
    }
}
impl Client {
pub async fn chat(&self, req: &ChatRequest) -> Result<ChatResponse> {
let mut req = req.clone();
req.stream = Some(false);
let (mut resp, meta) = self.post_json::<ChatRequest, ChatResponse>("/qai/v1/chat", &req).await?;
resp.cost_ticks = meta.cost_ticks;
resp.request_id = meta.request_id;
if resp.model.is_empty() {
resp.model = meta.model;
}
Ok(resp)
}
pub async fn chat_stream(&self, req: &ChatRequest) -> Result<ChatStream> {
let mut req = req.clone();
req.stream = Some(true);
let (resp, _meta) = self.post_stream_raw("/qai/v1/chat", &req).await?;
let byte_stream = resp.bytes_stream();
let event_stream = sse_to_events(byte_stream);
Ok(ChatStream {
inner: Box::pin(event_stream),
})
}
}
/// Converts the raw SSE body of a streaming chat response into `StreamEvent`s.
///
/// Only `data` fields are interpreted. Every other line (other SSE fields such as
/// `event:` / `id:`, `:` comments, blank separator lines) is skipped. Each non-empty `data`
/// line is treated as one complete JSON payload; multi-line `data` fields are not joined.
///
/// * `[DONE]` yields an event with `event_type == "done"` and `done == true`.
/// * A payload that does not parse as `RawStreamEvent` yields an `"error"` event, and the
///   stream continues with the next line.
/// * If reading the response body fails, one final `"error"` event describing the transport
///   failure is yielded instead of ending the stream silently.
fn sse_to_events<S>(byte_stream: S) -> impl Stream<Item = StreamEvent> + Send
where
    S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
{
    use futures_util::StreamExt;

    sse_split_lines(byte_stream).filter_map(|line| {
        let event = match line {
            SseLine::Line(text) => sse_data_payload(&text)
                // An empty `data` field carries no event; don't report it as a parse error.
                .filter(|payload| !payload.is_empty())
                .map(sse_parse_payload),
            SseLine::TransportError(message) => {
                let mut ev = sse_blank_event("error");
                ev.error = Some(format!("read SSE stream: {message}"));
                Some(ev)
            }
        };
        futures_util::future::ready(event)
    })
}

/// One item produced by `sse_split_lines`.
enum SseLine {
    /// A complete line with its `\n` / `\r\n` terminator removed.
    Line(String),
    /// Reading the underlying byte stream failed; nothing follows this item.
    TransportError(String),
}

/// Splits a response byte stream into text lines terminated by `\n`; an optional `\r`
/// before the `\n` is removed as well.
///
/// Raw bytes are buffered until a full line is available, so UTF-8 sequences split across
/// chunks decode correctly; invalid UTF-8 is replaced lossily. Once the byte stream ends or
/// fails it is never polled again. Any trailing bytes without a newline are emitted as a
/// final line, followed by a `TransportError` item if the stream failed.
fn sse_split_lines<S>(byte_stream: S) -> impl Stream<Item = SseLine> + Send
where
    S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
{
    use futures_util::StreamExt;

    // State: (byte stream, bytes not yet emitted, pending transport error, source finished).
    let initial = (Box::pin(byte_stream), Vec::<u8>::new(), None::<String>, false);
    futures_util::stream::unfold(
        initial,
        |(mut stream, mut buffer, mut failure, mut finished)| async move {
            loop {
                if let Some(newline) = buffer.iter().position(|&b| b == b'\n') {
                    // Remove the line in place rather than copying the unconsumed tail
                    // into a freshly allocated Vec for every line.
                    let line: Vec<u8> = buffer.drain(..=newline).collect();
                    let item = SseLine::Line(sse_decode_line(&line));
                    return Some((item, (stream, buffer, failure, finished)));
                }
                if finished {
                    if !buffer.is_empty() {
                        // Trailing bytes with no final newline still form a last line.
                        let line = std::mem::take(&mut buffer);
                        let item = SseLine::Line(sse_decode_line(&line));
                        return Some((item, (stream, buffer, failure, finished)));
                    }
                    // Report a transport failure exactly once, after all buffered data.
                    let message = failure.take()?;
                    let item = SseLine::TransportError(message);
                    return Some((item, (stream, buffer, failure, finished)));
                }
                match stream.next().await {
                    Some(Ok(chunk)) => buffer.extend_from_slice(&chunk),
                    Some(Err(err)) => {
                        failure = Some(err.to_string());
                        finished = true;
                    }
                    None => finished = true,
                }
            }
        },
    )
}

/// Decodes one raw line, dropping a trailing `\n` and then a trailing `\r` if present.
fn sse_decode_line(raw: &[u8]) -> String {
    let raw = raw.strip_suffix(b"\n").unwrap_or(raw);
    let raw = raw.strip_suffix(b"\r").unwrap_or(raw);
    String::from_utf8_lossy(raw).into_owned()
}

/// Returns the value of an SSE `data` field line, or `None` for any other line.
///
/// As in the SSE format, one space directly after the `:` is not part of the value, so both
/// `data: {...}` and `data:{...}` yield `{...}`.
fn sse_data_payload(line: &str) -> Option<&str> {
    let value = line.strip_prefix("data:")?;
    Some(value.strip_prefix(' ').unwrap_or(value))
}

/// Builds an event of type `event_type` with every payload field empty and `done == false`.
fn sse_blank_event(event_type: &str) -> StreamEvent {
    StreamEvent {
        event_type: event_type.to_string(),
        delta: None,
        tool_use: None,
        tool_use_start: None,
        tool_use_input_delta: None,
        tool_use_complete: None,
        usage: None,
        error: None,
        done: false,
    }
}

/// Parses one `data` payload into a `StreamEvent`.
///
/// `[DONE]` becomes a `"done"` event. Invalid JSON becomes an `"error"` event whose message
/// starts with `parse SSE:`. For recognised event types the matching payload field is filled
/// in, with missing wire fields defaulting to empty/zero. `heartbeat` and unrecognised types
/// carry only their `event_type`.
fn sse_parse_payload(payload: &str) -> StreamEvent {
    if payload == "[DONE]" {
        let mut ev = sse_blank_event("done");
        ev.done = true;
        return ev;
    }
    let raw: RawStreamEvent = match serde_json::from_str(payload) {
        Ok(raw) => raw,
        Err(e) => {
            let mut ev = sse_blank_event("error");
            ev.error = Some(format!("parse SSE: {e}"));
            return ev;
        }
    };
    let mut ev = sse_blank_event(&raw.event_type);
    match raw.event_type.as_str() {
        "content_delta" | "thinking_delta" => ev.delta = raw.delta,
        "tool_use" => {
            ev.tool_use = Some(StreamToolUse {
                id: raw.id.unwrap_or_default(),
                name: raw.name.unwrap_or_default(),
                input: raw.input.unwrap_or_default(),
            });
        }
        "tool_use_start" => {
            ev.tool_use_start = Some(StreamToolUseStart {
                id: raw.id.unwrap_or_default(),
                name: raw.name.unwrap_or_default(),
            });
        }
        "tool_use_input_delta" => {
            ev.tool_use_input_delta = Some(StreamToolUseInputDelta {
                id: raw.id.unwrap_or_default(),
                partial_json: raw.partial_json.unwrap_or_default(),
            });
        }
        "tool_use_complete" => {
            ev.tool_use_complete = Some(StreamToolUseComplete {
                id: raw.id.unwrap_or_default(),
                name: raw.name.unwrap_or_default(),
                input: raw.input.unwrap_or_default(),
            });
        }
        "usage" => {
            ev.usage = Some(ChatUsage {
                input_tokens: raw.input_tokens.unwrap_or(0),
                output_tokens: raw.output_tokens.unwrap_or(0),
                cost_ticks: raw.cost_ticks.unwrap_or(0),
            });
        }
        "error" => ev.error = raw.message,
        // `heartbeat` and unrecognised event types carry no payload.
        _ => {}
    }
    ev
}