Skip to main content

modelrelay/
runs.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    pin::Pin,
4    sync::{
5        atomic::{AtomicBool, Ordering},
6        Arc,
7    },
8    task::{Context, Poll},
9};
10
11use reqwest::Method;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use uuid::Uuid;
15
16use crate::{
17    client::ClientInner,
18    core::consume_ndjson_buffer,
19    errors::{Error, Result, TransportError, ValidationError},
20    generated::{RunsPendingToolsResponse, ToolCallId, ToolName},
21    http::{request_id_from_headers, validate_ndjson_content_type, HeaderList},
22    workflow::{
23        NodeId, NodeResultV0, PlanHash, RequestId, RunCostSummaryV0, RunEventV0, RunId, RunStatusV0,
24    },
25    workflow_intent::WorkflowIntentSpec,
26};
27
28#[cfg(feature = "streaming")]
29use crate::ndjson::classify_reqwest_error;
30#[cfg(feature = "streaming")]
31use futures_core::Stream;
32#[cfg(feature = "streaming")]
33use futures_util::{stream, StreamExt};
34
35#[derive(Clone)]
36pub struct RunsClient {
37    pub(crate) inner: Arc<ClientInner>,
38}
39
40#[derive(Debug, Clone, Serialize)]
41struct RunsCreateRequest {
42    spec: WorkflowIntentSpec,
43    #[serde(skip_serializing_if = "Option::is_none")]
44    input: Option<HashMap<String, Value>>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    model_override: Option<String>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    model_overrides: Option<RunsModelOverrides>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    session_id: Option<Uuid>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    stream: Option<bool>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    options: Option<RunsCreateOptionsV0>,
55}
56
57#[derive(Debug, Clone, Serialize)]
58struct RunsCreateFromPlanRequest {
59    plan_hash: PlanHash,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    input: Option<HashMap<String, Value>>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    model_override: Option<String>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    model_overrides: Option<RunsModelOverrides>,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    session_id: Option<Uuid>,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    stream: Option<bool>,
70    #[serde(skip_serializing_if = "Option::is_none")]
71    options: Option<RunsCreateOptionsV0>,
72}
73
74#[derive(Debug, Clone, Serialize)]
75struct RunsCreateOptionsV0 {
76    idempotency_key: String,
77}
78
79#[derive(Debug, Clone, Default)]
80pub struct RunsCreateOptions {
81    pub session_id: Option<Uuid>,
82    pub input: Option<HashMap<String, Value>>,
83    pub model_override: Option<String>,
84    pub model_overrides: Option<RunsModelOverrides>,
85    pub stream: Option<bool>,
86    pub idempotency_key: Option<String>,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct RunsModelOverrides {
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub nodes: Option<HashMap<String, String>>,
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub fanout_subnodes: Option<Vec<RunsFanoutSubnodeOverride>>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct RunsFanoutSubnodeOverride {
99    pub parent_id: String,
100    pub subnode_id: String,
101    pub model: String,
102}
103
104#[derive(Debug, Clone, Deserialize)]
105pub struct RunsCreateResponse {
106    pub run_id: RunId,
107    pub status: RunStatusV0,
108    pub plan_hash: PlanHash,
109}
110
111#[derive(Debug, Clone, Deserialize)]
112pub struct RunsGetResponse {
113    pub run_id: RunId,
114    pub status: RunStatusV0,
115    pub plan_hash: PlanHash,
116    pub cost_summary: RunCostSummaryV0,
117    #[serde(default)]
118    pub nodes: Vec<NodeResultV0>,
119    #[serde(default)]
120    pub outputs: HashMap<String, Value>,
121}
122
123#[derive(Debug, Clone, Serialize)]
124pub struct RunsToolResultsRequest {
125    pub node_id: NodeId,
126    pub step: u64,
127    pub request_id: RequestId,
128    pub results: Vec<RunsToolResultItemV0>,
129}
130
131#[derive(Debug, Clone, Serialize)]
132pub struct RunsToolCallV0 {
133    pub id: ToolCallId,
134    pub name: ToolName,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub arguments: Option<String>,
137}
138
139#[derive(Debug, Clone, Serialize)]
140pub struct RunsToolResultItemV0 {
141    pub tool_call: RunsToolCallV0,
142    pub output: String,
143}
144
145#[derive(Debug, Clone, Deserialize)]
146pub struct RunsToolResultsResponse {
147    pub accepted: u64,
148    pub status: RunStatusV0,
149}
150
// `RunsPendingToolsResponse` is imported from `crate::generated` above. It and its
// nested `RunsPendingToolsNodeV0` / `RunsPendingToolCallV0` types are generated from
// the OpenAPI spec rather than defined in this file.
153
154impl RunsClient {
155    pub async fn create(&self, spec: WorkflowIntentSpec) -> Result<RunsCreateResponse> {
156        self.create_with_options(spec, RunsCreateOptions::default())
157            .await
158    }
159
160    pub async fn create_with_options(
161        &self,
162        spec: WorkflowIntentSpec,
163        options: RunsCreateOptions,
164    ) -> Result<RunsCreateResponse> {
165        self.inner.ensure_auth()?;
166        if options.session_id.is_some_and(|id| id.is_nil()) {
167            return Err(Error::Validation(
168                ValidationError::new("session_id is required").with_field("session_id"),
169            ));
170        }
171
172        let options_payload = options.idempotency_key.as_ref().and_then(|key| {
173            let trimmed = key.trim();
174            if trimmed.is_empty() {
175                None
176            } else {
177                Some(RunsCreateOptionsV0 {
178                    idempotency_key: trimmed.to_string(),
179                })
180            }
181        });
182
183        let model_override = options
184            .model_override
185            .as_ref()
186            .and_then(|val| match val.trim() {
187                "" => None,
188                trimmed => Some(trimmed.to_string()),
189            });
190        let model_overrides = options.model_overrides.clone();
191
192        let mut builder = self.inner.request(Method::POST, "/runs")?;
193        builder = builder.json(&RunsCreateRequest {
194            spec,
195            input: options.input,
196            model_override,
197            model_overrides,
198            session_id: options.session_id,
199            stream: options.stream,
200            options: options_payload,
201        });
202        builder = self.inner.with_headers(
203            builder,
204            None,
205            &HeaderList::default(),
206            Some("application/json"),
207        )?;
208        builder = self.inner.with_timeout(builder, None, true);
209        let ctx = self.inner.make_context(&Method::POST, "/runs", None, None);
210        self.inner
211            .execute_json(builder, Method::POST, None, ctx)
212            .await
213    }
214
215    pub async fn create_with_session(
216        &self,
217        spec: WorkflowIntentSpec,
218        session_id: Uuid,
219    ) -> Result<RunsCreateResponse> {
220        self.create_with_options(
221            spec,
222            RunsCreateOptions {
223                session_id: Some(session_id),
224                ..RunsCreateOptions::default()
225            },
226        )
227        .await
228    }
229
230    /// Creates a workflow run using a precompiled plan hash.
231    ///
232    /// Use [`crate::WorkflowsClient::compile`] to compile a workflow spec and obtain a `plan_hash`,
233    /// then use this method to start runs without re-compiling each time.
234    /// This is useful for workflows that are run repeatedly with the same structure
235    /// but different inputs.
236    ///
237    /// The plan_hash must have been compiled in the current server session;
238    /// if the server has restarted since compilation, the plan will not be found
239    /// and you'll need to recompile.
240    pub async fn create_from_plan(&self, plan_hash: PlanHash) -> Result<RunsCreateResponse> {
241        self.create_from_plan_with_options(plan_hash, RunsCreateOptions::default())
242            .await
243    }
244
245    /// Creates a workflow run using a precompiled plan hash with options.
246    ///
247    /// See [`Self::create_from_plan`] for details on plan_hash usage.
248    pub async fn create_from_plan_with_options(
249        &self,
250        plan_hash: PlanHash,
251        options: RunsCreateOptions,
252    ) -> Result<RunsCreateResponse> {
253        self.inner.ensure_auth()?;
254        // PlanHash derefs to String, auto-deref handles is_empty()
255        if plan_hash.is_empty() {
256            return Err(Error::Validation(
257                ValidationError::new("plan_hash is required").with_field("plan_hash"),
258            ));
259        }
260        if options.session_id.is_some_and(|id| id.is_nil()) {
261            return Err(Error::Validation(
262                ValidationError::new("session_id is required").with_field("session_id"),
263            ));
264        }
265
266        let options_payload = options.idempotency_key.as_ref().and_then(|key| {
267            let trimmed = key.trim();
268            if trimmed.is_empty() {
269                None
270            } else {
271                Some(RunsCreateOptionsV0 {
272                    idempotency_key: trimmed.to_string(),
273                })
274            }
275        });
276
277        let model_override = options
278            .model_override
279            .as_ref()
280            .and_then(|val| match val.trim() {
281                "" => None,
282                trimmed => Some(trimmed.to_string()),
283            });
284        let model_overrides = options.model_overrides.clone();
285
286        let mut builder = self.inner.request(Method::POST, "/runs")?;
287        builder = builder.json(&RunsCreateFromPlanRequest {
288            plan_hash,
289            input: options.input,
290            model_override,
291            model_overrides,
292            session_id: options.session_id,
293            stream: options.stream,
294            options: options_payload,
295        });
296        builder = self.inner.with_headers(
297            builder,
298            None,
299            &HeaderList::default(),
300            Some("application/json"),
301        )?;
302        builder = self.inner.with_timeout(builder, None, true);
303        let ctx = self.inner.make_context(&Method::POST, "/runs", None, None);
304        self.inner
305            .execute_json(builder, Method::POST, None, ctx)
306            .await
307    }
308
309    pub async fn get(&self, run_id: RunId) -> Result<RunsGetResponse> {
310        self.inner.ensure_auth()?;
311        if run_id.0.is_nil() {
312            return Err(Error::Validation(
313                ValidationError::new("run_id is required").with_field("run_id"),
314            ));
315        }
316        let path = format!("/runs/{}", run_id);
317        let builder = self.inner.request(Method::GET, &path)?;
318        let builder = self.inner.with_headers(
319            builder,
320            None,
321            &HeaderList::default(),
322            Some("application/json"),
323        )?;
324        let builder = self.inner.with_timeout(builder, None, true);
325        let ctx = self.inner.make_context(&Method::GET, &path, None, None);
326        self.inner
327            .execute_json(builder, Method::GET, None, ctx)
328            .await
329    }
330
331    pub async fn pending_tools(&self, run_id: RunId) -> Result<RunsPendingToolsResponse> {
332        self.inner.ensure_auth()?;
333        if run_id.0.is_nil() {
334            return Err(Error::Validation(
335                ValidationError::new("run_id is required").with_field("run_id"),
336            ));
337        }
338        let path = format!("/runs/{}/pending-tools", run_id);
339        let builder = self.inner.request(Method::GET, &path)?;
340        let builder = self.inner.with_headers(
341            builder,
342            None,
343            &HeaderList::default(),
344            Some("application/json"),
345        )?;
346        let builder = self.inner.with_timeout(builder, None, true);
347        let ctx = self.inner.make_context(&Method::GET, &path, None, None);
348        self.inner
349            .execute_json(builder, Method::GET, None, ctx)
350            .await
351    }
352
353    pub async fn submit_tool_results(
354        &self,
355        run_id: RunId,
356        req: RunsToolResultsRequest,
357    ) -> Result<RunsToolResultsResponse> {
358        self.inner.ensure_auth()?;
359        if run_id.0.is_nil() {
360            return Err(Error::Validation(
361                ValidationError::new("run_id is required").with_field("run_id"),
362            ));
363        }
364        let path = format!("/runs/{}/tool-results", run_id);
365        let mut builder = self.inner.request(Method::POST, &path)?;
366        builder = builder.json(&req);
367        builder = self.inner.with_headers(
368            builder,
369            None,
370            &HeaderList::default(),
371            Some("application/json"),
372        )?;
373        builder = self.inner.with_timeout(builder, None, true);
374        let ctx = self.inner.make_context(&Method::POST, &path, None, None);
375        self.inner
376            .execute_json(builder, Method::POST, None, ctx)
377            .await
378    }
379
380    #[cfg(feature = "streaming")]
381    pub async fn stream_events(
382        &self,
383        run_id: RunId,
384        after_seq: Option<i64>,
385        limit: Option<i64>,
386    ) -> Result<RunEventStreamHandle> {
387        self.inner.ensure_auth()?;
388        if run_id.0.is_nil() {
389            return Err(Error::Validation(
390                ValidationError::new("run_id is required").with_field("run_id"),
391            ));
392        }
393        let mut path = format!("/runs/{}/events", run_id);
394        let mut q = vec![];
395        if let Some(seq) = after_seq {
396            if seq > 0 {
397                q.push(("after_seq", seq.to_string()));
398            }
399        }
400        if let Some(lim) = limit {
401            if lim > 0 {
402                q.push(("limit", lim.to_string()));
403            }
404        }
405        if !q.is_empty() {
406            path.push('?');
407            path.push_str(
408                &q.into_iter()
409                    .map(|(k, v)| format!("{k}={v}"))
410                    .collect::<Vec<_>>()
411                    .join("&"),
412            );
413        }
414
415        let builder = self.inner.request(Method::GET, &path)?;
416        let builder = self.inner.with_headers(
417            builder,
418            None,
419            &HeaderList::default(),
420            Some("application/x-ndjson"),
421        )?;
422
423        let retry = self.inner.retry.clone();
424        let ctx = self.inner.make_context(&Method::GET, &path, None, None);
425        let resp = self
426            .inner
427            .send_with_retry(builder, Method::GET, retry, ctx.clone())
428            .await?;
429
430        validate_ndjson_content_type(resp.headers(), resp.status().as_u16())?;
431
432        let request_id = request_id_from_headers(resp.headers());
433        Ok(RunEventStreamHandle::new(resp, request_id))
434    }
435}
436
/// Handle to a run's event stream, returned by `RunsClient::stream_events`.
///
/// Yields `Result<RunEventV0>` items through its `Stream` implementation and can
/// be stopped with `cancel`; dropping the handle also sets the cancellation flag.
#[cfg(feature = "streaming")]
pub struct RunEventStreamHandle {
    /// Request id extracted from the response headers, if any.
    request_id: Option<String>,
    /// Boxed event stream built from the response body.
    stream: Pin<Box<dyn Stream<Item = Result<RunEventV0>> + Send>>,
    /// Cancellation flag shared with the inner stream.
    cancelled: Arc<AtomicBool>,
}
443
#[cfg(feature = "streaming")]
impl RunEventStreamHandle {
    /// Wraps `response` in an event stream that shares a fresh cancellation flag
    /// with the returned handle.
    fn new(response: reqwest::Response, request_id: Option<String>) -> Self {
        let cancel_flag = Arc::new(AtomicBool::new(false));
        let events = build_run_events_stream(response, Arc::clone(&cancel_flag));
        Self {
            request_id,
            stream: Box::pin(events),
            cancelled: cancel_flag,
        }
    }

    /// Returns the request id captured from the response headers, if one was present.
    pub fn request_id(&self) -> Option<&str> {
        self.request_id.as_deref()
    }

    /// Sets the cancellation flag; the inner stream checks it before producing
    /// each item and ends once it is set.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }
}
464
#[cfg(feature = "streaming")]
impl Drop for RunEventStreamHandle {
    /// Sets the shared cancellation flag when the handle goes away, exactly as
    /// an explicit `cancel()` call would.
    fn drop(&mut self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }
}
471
#[cfg(feature = "streaming")]
impl Stream for RunEventStreamHandle {
    type Item = Result<RunEventV0>;

    /// Forwards polling to the boxed inner event stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Every field of the handle is `Unpin` (the inner stream is already a
        // `Pin<Box<..>>`), so the handle is `Unpin` and `get_mut` is available;
        // no `unsafe` pin projection is needed.
        self.get_mut().stream.as_mut().poll_next(cx)
    }
}
481
/// State for run event stream processing.
///
/// Simpler than NdjsonStreamState since runs don't need timeouts or telemetry.
#[cfg(feature = "streaming")]
struct RunEventStreamState<B> {
    /// Byte stream of the HTTP response body.
    body: B,
    /// Trailing bytes of the previous chunk that form an incomplete UTF-8
    /// sequence; they are prepended to the next chunk before decoding.
    partial_utf8: Vec<u8>,
    /// Decoded text that has not yet been consumed as complete NDJSON lines.
    buffer: String,
    /// Cancellation flag shared with the stream handle.
    cancelled: Arc<AtomicBool>,
    /// Parsed events and parse errors, queued in the order they appeared on the wire.
    pending: VecDeque<Result<RunEventV0>>,
    /// Set once the body reported end-of-stream or a transport error, so the
    /// body is never polled again afterwards.
    body_done: bool,
}

#[cfg(feature = "streaming")]
impl<B> RunEventStreamState<B> {
    /// Creates an empty state around `body`, sharing `cancelled` with the handle.
    fn new(body: B, cancelled: Arc<AtomicBool>) -> Self {
        Self {
            body,
            partial_utf8: Vec::new(),
            buffer: String::new(),
            cancelled,
            pending: VecDeque::new(),
            body_done: false,
        }
    }

    /// Returns `true` once the cancellation flag has been set.
    fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// Decodes `chunk` as UTF-8, joining it with any bytes carried over from the
    /// previous chunk.
    ///
    /// A multi-byte character cut off at the end of the chunk is not an error:
    /// its leading bytes are stashed in `partial_utf8` and the text before it is
    /// returned. Bytes that can never form valid UTF-8 yield `Err`, and the
    /// carried bytes and the chunk are discarded.
    fn decode_chunk(
        &mut self,
        chunk: &[u8],
    ) -> std::result::Result<String, std::str::Utf8Error> {
        let mut bytes = std::mem::take(&mut self.partial_utf8);
        bytes.extend_from_slice(chunk);
        match String::from_utf8(bytes) {
            Ok(text) => Ok(text),
            Err(err) => {
                let utf8_err = err.utf8_error();
                // `error_len() == None` means the input ended in the middle of a
                // character; anything else is genuinely invalid.
                if utf8_err.error_len().is_some() {
                    return Err(utf8_err);
                }
                let mut valid = err.into_bytes();
                self.partial_utf8 = valid.split_off(utf8_err.valid_up_to());
                String::from_utf8(valid).map_err(|e| e.utf8_error())
            }
        }
    }

    /// Moves every complete NDJSON line out of `buffer`, queueing the parse
    /// result of each one (event or error) in order, and keeps the remainder.
    fn queue_complete_lines(&mut self) {
        let (events, remainder) = consume_ndjson_buffer(&self.buffer);
        self.buffer = remainder;
        for raw in events {
            self.pending.push_back(parse_run_event(&raw.data));
        }
    }
}

/// Turns an NDJSON HTTP response into a stream of parsed run events.
///
/// The stream ends when the body ends, after a transport error has been
/// yielded, or as soon as `cancelled` is observed set. Parse errors are yielded
/// in wire order and do not stop the remaining events in the chunk from being
/// yielded.
#[cfg(feature = "streaming")]
fn build_run_events_stream(
    response: reqwest::Response,
    cancelled: Arc<AtomicBool>,
) -> impl Stream<Item = Result<RunEventV0>> + Send {
    let body = response.bytes_stream();
    let state = RunEventStreamState::new(body, cancelled);

    stream::unfold(state, |mut state| async move {
        loop {
            // Check cancellation
            if state.is_cancelled() {
                return None;
            }

            // Return queued events / parse errors
            if let Some(item) = state.pending.pop_front() {
                return Some((item, state));
            }

            // Never poll the body again after it has finished or failed.
            if state.body_done {
                return None;
            }

            match state.body.next().await {
                Some(Ok(chunk)) => match state.decode_chunk(&chunk) {
                    Ok(text) => {
                        state.buffer.push_str(&text);
                        state.queue_complete_lines();
                    }
                    Err(e) => {
                        let err = Error::StreamProtocol {
                            message: format!("invalid UTF-8 in stream: {}", e),
                            raw_data: None,
                        };
                        return Some((Err(err), state));
                    }
                },
                Some(Err(err)) => {
                    state.body_done = true;
                    let error = Error::Transport(TransportError {
                        kind: classify_reqwest_error(&err),
                        message: err.to_string(),
                        source: Some(err),
                        retries: None,
                    });
                    return Some((Err(error), state));
                }
                None => {
                    // Stream ended, process remaining buffer
                    state.body_done = true;
                    state.queue_complete_lines();
                    // NOTE(review): whatever `consume_ndjson_buffer` leaves in the
                    // remainder is discarded here, as before — confirm it already
                    // emits a final line that lacks a trailing newline.
                    state.buffer.clear();

                    // The body ended in the middle of a multi-byte character.
                    if !state.partial_utf8.is_empty() {
                        state.partial_utf8.clear();
                        state.pending.push_back(Err(Error::StreamProtocol {
                            message: "invalid UTF-8 in stream: incomplete multi-byte sequence at end of stream".to_string(),
                            raw_data: None,
                        }));
                    }
                    // Loop around to drain `pending`, then finish via `body_done`.
                }
            }
        }
    })
}
587
/// Deserializes one NDJSON line into a `RunEventV0` and runs its `validate` check.
///
/// Both failure modes are reported as `Error::StreamProtocol` carrying a copy of
/// the offending `raw_data`.
#[cfg(feature = "streaming")]
fn parse_run_event(raw_data: &str) -> Result<RunEventV0> {
    // Shared constructor so both failure paths attach the raw payload the same way.
    let protocol_error = |message: String| Error::StreamProtocol {
        message,
        raw_data: Some(raw_data.to_string()),
    };

    let event: RunEventV0 = serde_json::from_str(raw_data)
        .map_err(|err| protocol_error(format!("failed to parse run event: {err}")))?;

    if let Err(err) = event.validate() {
        return Err(protocol_error(format!("invalid run event: {err}")));
    }
    Ok(event)
}