Skip to main content

quantum_sdk/
agent.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use pin_project_lite::pin_project;
7use serde::{Deserialize, Serialize};
8
9use crate::client::Client;
10use crate::error::Result;
11use crate::session::ContextConfig;
12
13// ---------------------------------------------------------------------------
14// Agent
15// ---------------------------------------------------------------------------
16
17/// Describes a worker agent in a multi-agent run.
18#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct AgentWorker {
20    /// Worker name.
21    pub name: String,
22
23    /// Model ID for this worker.
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub model: Option<String>,
26
27    /// Worker tier (e.g. "fast", "thinking").
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub tier: Option<String>,
30
31    /// Description of this worker's role.
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub description: Option<String>,
34}
35
36/// Request body for an agent run.
37#[derive(Debug, Clone, Serialize, Default)]
38pub struct AgentRequest {
39    /// Session identifier for continuity across runs.
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub session_id: Option<String>,
42
43    /// The task for the agent to accomplish.
44    pub task: String,
45
46    /// Model for the conductor agent.
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub conductor_model: Option<String>,
49
50    /// Worker agents available to the conductor.
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub workers: Option<Vec<AgentWorker>>,
53
54    /// Maximum number of steps before stopping.
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub max_steps: Option<i32>,
57
58    /// System prompt for the conductor.
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub system_prompt: Option<String>,
61}
62
63// ---------------------------------------------------------------------------
64// Mission
65// ---------------------------------------------------------------------------
66
67/// Describes a named worker for a mission (map keyed by name).
68#[derive(Debug, Clone, Serialize, Deserialize, Default)]
69pub struct MissionWorker {
70    /// Model ID for this worker.
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub model: Option<String>,
73
74    /// Worker tier.
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub tier: Option<String>,
77
78    /// Description of this worker's purpose.
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub description: Option<String>,
81}
82
83/// Request body for a mission run.
84#[derive(Debug, Clone, Serialize, Default)]
85pub struct MissionRequest {
86    /// The high-level goal for the mission.
87    pub goal: String,
88
89    /// Execution strategy hint.
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub strategy: Option<String>,
92
93    /// Model for the conductor.
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub conductor_model: Option<String>,
96
97    /// Named workers (key = worker name).
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub workers: Option<HashMap<String, MissionWorker>>,
100
101    /// Maximum number of steps.
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub max_steps: Option<i32>,
104
105    /// System prompt for the conductor.
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub system_prompt: Option<String>,
108
109    /// Session identifier for continuity.
110    #[serde(skip_serializing_if = "Option::is_none")]
111    pub session_id: Option<String>,
112
113    /// Whether to auto-plan before execution.
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub auto_plan: Option<bool>,
116
117    /// Context management configuration.
118    #[serde(skip_serializing_if = "Option::is_none")]
119    pub context_config: Option<ContextConfig>,
120}
121
122// ---------------------------------------------------------------------------
123// SSE Stream
124// ---------------------------------------------------------------------------
125
126/// A single event from an agent or mission SSE stream.
127#[derive(Debug, Clone, Deserialize)]
128pub struct AgentStreamEvent {
129    /// Event type (e.g. "step", "thought", "tool_call", "tool_result", "message", "error", "done").
130    #[serde(rename = "type", default)]
131    pub event_type: String,
132
133    /// Raw JSON payload for caller to interpret.
134    #[serde(flatten)]
135    pub data: HashMap<String, serde_json::Value>,
136}
137
138pin_project! {
139    /// An async stream of [`AgentStreamEvent`]s from an agent or mission SSE response.
140    pub struct AgentStream {
141        #[pin]
142        inner: Pin<Box<dyn Stream<Item = AgentStreamEvent> + Send>>,
143    }
144}
145
146impl Stream for AgentStream {
147    type Item = AgentStreamEvent;
148
149    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
150        self.project().inner.poll_next(cx)
151    }
152}
153
154/// Converts a byte stream into a stream of parsed [`AgentStreamEvent`]s.
155fn sse_to_agent_events<S>(byte_stream: S) -> impl Stream<Item = AgentStreamEvent> + Send
156where
157    S: Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
158{
159    let pinned_stream = Box::pin(byte_stream);
160
161    let line_stream = futures_util::stream::unfold(
162        (pinned_stream, String::new()),
163        |(mut stream, mut buffer)| async move {
164            use futures_util::StreamExt;
165            loop {
166                if let Some(newline_pos) = buffer.find('\n') {
167                    let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
168                    buffer = buffer[newline_pos + 1..].to_string();
169                    return Some((line, (stream, buffer)));
170                }
171
172                match stream.next().await {
173                    Some(Ok(chunk)) => {
174                        buffer.push_str(&String::from_utf8_lossy(&chunk));
175                    }
176                    Some(Err(_)) | None => {
177                        if !buffer.is_empty() {
178                            let remaining = std::mem::take(&mut buffer);
179                            return Some((remaining, (stream, buffer)));
180                        }
181                        return None;
182                    }
183                }
184            }
185        },
186    );
187
188    let pinned_lines = Box::pin(line_stream);
189    futures_util::stream::unfold(pinned_lines, |mut lines| async move {
190        use futures_util::StreamExt;
191        loop {
192            let line = lines.next().await?;
193
194            if !line.starts_with("data: ") {
195                continue;
196            }
197            let payload = &line["data: ".len()..];
198
199            if payload == "[DONE]" {
200                let ev = AgentStreamEvent {
201                    event_type: "done".to_string(),
202                    data: HashMap::new(),
203                };
204                return Some((ev, lines));
205            }
206
207            match serde_json::from_str::<AgentStreamEvent>(payload) {
208                Ok(ev) => return Some((ev, lines)),
209                Err(e) => {
210                    let mut data = HashMap::new();
211                    data.insert(
212                        "error".to_string(),
213                        serde_json::Value::String(format!("parse SSE: {e}")),
214                    );
215                    let ev = AgentStreamEvent {
216                        event_type: "error".to_string(),
217                        data,
218                    };
219                    return Some((ev, lines));
220                }
221            }
222        }
223    })
224}
225
226impl Client {
227    /// Starts an agent run and returns an SSE event stream.
228    ///
229    /// The agent orchestrates one or more worker models to accomplish the task,
230    /// streaming progress events as it works.
231    pub async fn agent_run(&self, req: &AgentRequest) -> Result<AgentStream> {
232        let (resp, _meta) = self.post_stream_raw("/qai/v1/agent", req).await?;
233        let byte_stream = resp.bytes_stream();
234        let event_stream = sse_to_agent_events(byte_stream);
235        Ok(AgentStream {
236            inner: Box::pin(event_stream),
237        })
238    }
239
240    /// Starts a mission run and returns an SSE event stream.
241    ///
242    /// Missions are higher-level than agents -- they can auto-plan, assign
243    /// named workers, and manage context across multiple steps.
244    pub async fn mission_run(&self, req: &MissionRequest) -> Result<AgentStream> {
245        let (resp, _meta) = self.post_stream_raw("/qai/v1/missions", req).await?;
246        let byte_stream = resp.bytes_stream();
247        let event_stream = sse_to_agent_events(byte_stream);
248        Ok(AgentStream {
249            inner: Box::pin(event_stream),
250        })
251    }
252}