use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::Stream;
use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize};
use crate::client::Client;
use crate::error::Result;
use crate::session::ContextConfig;
/// A worker entry supplied in [`AgentRequest::workers`].
///
/// Every optional field is left out of the serialized form when it is `None`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct AgentWorker {
    /// Worker name (always serialized).
    pub name: String,
    /// Requested model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Requested tier, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tier: Option<String>,
    /// Description of the worker, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
#[derive(Debug, Clone, Serialize, Default)]
pub struct AgentRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
pub task: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub conductor_model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub workers: Option<Vec<AgentWorker>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_steps: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_prompt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub context_config: Option<crate::session::ContextConfig>,
}
/// A worker entry supplied in [`MissionRequest::workers`].
///
/// All fields are optional and are omitted from the serialized form when `None`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct MissionWorker {
    /// Requested model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Requested tier, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tier: Option<String>,
    /// Description of the worker, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Escalation target, if any (presumably another worker or model — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub escalate_to: Option<String>,
    /// Retry limit, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<i32>,
}
/// Request body for [`Client::mission_run`], which POSTs it to `/qai/v1/missions`.
///
/// Only `goal` is always serialized; every `Option` field is omitted from the
/// serialized request when it is `None`.
#[derive(Clone, Debug, Default, Serialize)]
pub struct MissionRequest {
    /// The mission goal (always serialized).
    pub goal: String,
    /// Strategy name, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strategy: Option<String>,
    /// Model requested for the conductor, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub conductor_model: Option<String>,
    /// Tier requested for the conductor, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub conductor_tier: Option<String>,
    /// Worker definitions keyed by a caller-chosen string, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub workers: Option<HashMap<String, MissionWorker>>,
    /// Step limit, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_steps: Option<i32>,
    /// System prompt override, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Session identifier, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Auto-planning toggle, if set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_plan: Option<bool>,
    /// Context configuration, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_config: Option<ContextConfig>,
    /// Model requested for workers, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_model: Option<String>,
    /// Deployment identifier, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub deployment_id: Option<String>,
    /// Build command, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub build_command: Option<String>,
    /// Workspace path, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub workspace_path: Option<String>,
}
/// Alternative name for [`AgentWorker`]; both names refer to the same type.
pub type AgentWorkerConfig = AgentWorker;
/// Alternative name for [`MissionWorker`]; both names refer to the same type.
pub type MissionWorkerConfig = MissionWorker;
/// One event decoded from an agent or mission SSE stream.
///
/// The payload's JSON `type` field becomes `event_type` (an empty string when
/// the field is absent); every other top-level field is collected into `data`.
#[derive(Clone, Debug, Deserialize)]
pub struct AgentStreamEvent {
    /// Event kind, read from the payload's `type` field.
    #[serde(rename = "type", default)]
    pub event_type: String,
    /// All remaining payload fields, keyed by field name.
    #[serde(flatten)]
    pub data: HashMap<String, serde_json::Value>,
}
/// Alternative name for [`AgentStreamEvent`], the item type of agent run streams.
pub type AgentEvent = AgentStreamEvent;
/// Alternative name for [`AgentStreamEvent`]; mission runs yield the same event type.
pub type MissionEvent = AgentStreamEvent;
pin_project! {
    /// Stream of [`AgentStreamEvent`]s returned by [`Client::agent_run`] and
    /// [`Client::mission_run`].
    pub struct AgentStream {
        // Type-erased event stream built from the SSE response body.
        #[pin]
        inner: Pin<Box<dyn Stream<Item = AgentStreamEvent> + Send>>,
    }
}
/// `AgentStream` is a thin wrapper: every poll is forwarded to the boxed
/// inner stream.
impl Stream for AgentStream {
    type Item = AgentStreamEvent;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Project through the pinned wrapper to reach the pinned inner field.
        let projected = self.project();
        projected.inner.poll_next(cx)
    }
}
fn sse_to_agent_events<S>(byte_stream: S) -> impl Stream<Item = AgentStreamEvent> + Send
where
S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
{
let pinned_stream = Box::pin(byte_stream);
let line_stream = futures_util::stream::unfold(
(pinned_stream, String::new()),
|(mut stream, mut buffer)| async move {
use futures_util::StreamExt;
loop {
if let Some(newline_pos) = buffer.find('\n') {
let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
buffer = buffer[newline_pos + 1..].to_string();
return Some((line, (stream, buffer)));
}
match stream.next().await {
Some(Ok(chunk)) => {
buffer.push_str(&String::from_utf8_lossy(&chunk));
}
Some(Err(_)) | None => {
if !buffer.is_empty() {
let remaining = std::mem::take(&mut buffer);
return Some((remaining, (stream, buffer)));
}
return None;
}
}
}
},
);
let pinned_lines = Box::pin(line_stream);
futures_util::stream::unfold(pinned_lines, |mut lines| async move {
use futures_util::StreamExt;
loop {
let line = lines.next().await?;
if !line.starts_with("data: ") {
continue;
}
let payload = &line["data: ".len()..];
if payload == "[DONE]" {
let ev = AgentStreamEvent {
event_type: "done".to_string(),
data: HashMap::new(),
};
return Some((ev, lines));
}
match serde_json::from_str::<AgentStreamEvent>(payload) {
Ok(ev) => return Some((ev, lines)),
Err(e) => {
let mut data = HashMap::new();
data.insert(
"error".to_string(),
serde_json::Value::String(format!("parse SSE: {e}")),
);
let ev = AgentStreamEvent {
event_type: "error".to_string(),
data,
};
return Some((ev, lines));
}
}
}
})
}
impl Client {
    /// Starts an agent run and returns its event stream.
    ///
    /// POSTs `req` to `/qai/v1/agent` and adapts the streamed SSE response body
    /// into [`AgentStreamEvent`]s.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `post_stream_raw` if the request cannot
    /// be started.
    pub async fn agent_run(&self, req: &AgentRequest) -> Result<AgentStream> {
        let (response, _meta) = self.post_stream_raw("/qai/v1/agent", req).await?;
        let events = sse_to_agent_events(response.bytes_stream());
        Ok(AgentStream {
            inner: Box::pin(events),
        })
    }

    /// Starts a mission run and returns its event stream.
    ///
    /// POSTs `req` to `/qai/v1/missions` and adapts the streamed SSE response
    /// body into [`AgentStreamEvent`]s.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `post_stream_raw` if the request cannot
    /// be started.
    pub async fn mission_run(&self, req: &MissionRequest) -> Result<AgentStream> {
        let (response, _meta) = self.post_stream_raw("/qai/v1/missions", req).await?;
        let events = sse_to_agent_events(response.bytes_stream());
        Ok(AgentStream {
            inner: Box::pin(events),
        })
    }
}