synwire-core 0.1.0

Core traits and types for the Synwire AI framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
//! Agent runner and execution loop.
//!
//! `Runner` drives the agent turn loop:
//! session lookup → middleware chain → model invocation → tool dispatch →
//! directive execution → event emission → usage tracking.
//!
//! It enforces `max_turns` and `max_budget` limits, handles model errors with
//! configurable retry / fallback, and supports graceful and force stop.

use std::sync::Arc;

use serde_json::Value;
use tokio::sync::{Mutex, mpsc};

use crate::agents::agent_node::Agent;
use crate::agents::error::AgentError;
use crate::agents::streaming::{AgentEvent, TerminationReason};
use crate::agents::usage::Usage;

// ---------------------------------------------------------------------------
// Stop signal
// ---------------------------------------------------------------------------

/// How a caller asked the runner to stop.
///
/// Delivered through the runner's stop channel by `Runner::stop_graceful`
/// and `Runner::stop_force`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StopKind {
    /// Finish cleanly; the run reports `TerminationReason::Stopped`.
    Graceful,
    /// Abort without draining; the run reports `TerminationReason::Aborted`.
    Force,
}

// ---------------------------------------------------------------------------
// RunErrorAction
// ---------------------------------------------------------------------------

/// Decision the runner takes after a model call fails.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum RunErrorAction {
    /// Re-issue the same request; bounded by `RunnerConfig::max_retries`.
    Retry,
    /// Ignore the error and move on to the next turn.
    Continue,
    /// End the run, reporting the contained message as an error.
    Abort(String),
    /// Retry the request using the named model instead of the current one.
    SwitchModel(String),
}

// ---------------------------------------------------------------------------
// RunnerConfig
// ---------------------------------------------------------------------------

/// Per-run options passed to `Runner::run`.
#[derive(Debug, Clone)]
pub struct RunnerConfig {
    /// Model to use for this run instead of the runner's current model.
    pub model_override: Option<String>,
    /// Session to resume; `None` starts a new session.
    /// NOTE(review): not read by the run loop in this file.
    pub session_id: Option<String>,
    /// Retry budget for a retryable model error before the run switches to
    /// the agent's fallback model or aborts.
    pub max_retries: u32,
}

impl RunnerConfig {
    /// Retry budget used by `RunnerConfig::default()`.
    const DEFAULT_MAX_RETRIES: u32 = 3;
}

impl Default for RunnerConfig {
    /// No model override, no session to resume, three retries.
    fn default() -> Self {
        Self {
            max_retries: Self::DEFAULT_MAX_RETRIES,
            session_id: None,
            model_override: None,
        }
    }
}

// ---------------------------------------------------------------------------
// Runner
// ---------------------------------------------------------------------------

/// Drives the agent execution loop.
///
/// The runner is stateless between runs; all per-run state is held in the
/// channel and local variables inside `run`.
#[derive(Debug)]
pub struct Runner<O: serde::Serialize + Send + Sync + 'static = ()> {
    agent: Arc<Agent<O>>,
    /// Current model — may be changed via `set_model`.
    current_model: Mutex<String>,
    /// Stop signal sender.
    stop_tx: Mutex<Option<mpsc::Sender<StopKind>>>,
}

impl<O: serde::Serialize + Send + Sync + 'static> Runner<O> {
    /// Create a runner wrapping the given agent.
    #[must_use]
    pub fn new(agent: Agent<O>) -> Self {
        let model = agent.model_name().to_string();
        Self {
            agent: Arc::new(agent),
            current_model: Mutex::new(model),
            stop_tx: Mutex::new(None),
        }
    }

    /// Dynamically switch the model for subsequent turns, preserving
    /// conversation history.
    pub async fn set_model(&self, model: impl Into<String>) {
        let mut guard = self.current_model.lock().await;
        *guard = model.into();
        tracing::info!(model = %*guard, "Runner: model switched");
    }

    /// Send a graceful stop signal.  The runner will finish any in-flight
    /// tool call, then emit `TurnComplete { reason: Stopped }`.
    pub async fn stop_graceful(&self) {
        if let Some(tx) = self.stop_tx.lock().await.as_ref() {
            let _ = tx.send(StopKind::Graceful).await;
        }
    }

    /// Send a force stop signal.  The runner cancels immediately and emits
    /// `TurnComplete { reason: Aborted }`.
    pub async fn stop_force(&self) {
        if let Some(tx) = self.stop_tx.lock().await.as_ref() {
            let _ = tx.send(StopKind::Force).await;
        }
    }

    /// Run the agent with the given input, yielding events over a channel.
    ///
    /// # Event stream
    /// Events are sent on the returned receiver.  The stream ends when the
    /// receiver is closed (after a `TurnComplete` or `Error` event).
    ///
    /// # Errors
    /// Returns `AgentError` if setup fails before the event stream starts.
    pub async fn run(
        &self,
        input: Value,
        config: RunnerConfig,
    ) -> Result<mpsc::Receiver<AgentEvent>, AgentError> {
        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(128);
        let (stop_tx, stop_rx) = mpsc::channel::<StopKind>(1);

        // Store stop sender so callers can signal stop.
        *self.stop_tx.lock().await = Some(stop_tx);

        let agent = Arc::clone(&self.agent);
        let model = self.current_model.lock().await.clone();

        let _handle = tokio::spawn(async move {
            run_loop(agent, input, config, model, event_tx, stop_rx).await;
        });

        Ok(event_rx)
    }
}

// ---------------------------------------------------------------------------
// Core loop (spawned task)
// ---------------------------------------------------------------------------

/// Body of the task spawned by `Runner::run`.
///
/// Each iteration:
/// 1. polls for a stop request;
/// 2. enforces `max_turns` and `max_budget`;
/// 3. invokes the model and emits the resulting events.
///
/// Model errors are routed through `dispatch_model_error`. The task ends in
/// one of two ways: after emitting a terminal `TurnComplete` / `Error` event,
/// or silently once the event receiver has been dropped.
#[allow(clippy::too_many_lines)] // one linear state machine over shared per-run locals
async fn run_loop<O: serde::Serialize + Send + Sync + 'static>(
    agent: Arc<Agent<O>>,
    input: Value,
    config: RunnerConfig,
    initial_model: String,
    event_tx: mpsc::Sender<AgentEvent>,
    mut stop_rx: mpsc::Receiver<StopKind>,
) {
    let max_turns = agent.max_turn_count();
    let max_budget = agent.budget_limit();
    let max_retries = config.max_retries;

    let mut current_model = config.model_override.unwrap_or(initial_model);
    let mut turn: u32 = 0;
    let mut cumulative_cost: f64 = 0.0;
    let mut retry_count: u32 = 0;

    // Seed conversation with the user's input.
    let mut messages: Vec<Value> = vec![serde_json::json!({ "role": "user", "content": input })];

    loop {
        // The consumer is gone: further model calls would only burn budget.
        if event_tx.is_closed() {
            tracing::debug!(turn, "event receiver dropped; ending run");
            return;
        }

        // Non-blocking poll for a stop request. `Empty` means none was sent;
        // `Disconnected` means the runner dropped or replaced the sender, which
        // also means no stop for this run.
        if let Ok(kind) = stop_rx.try_recv() {
            let reason = match kind {
                StopKind::Graceful => TerminationReason::Stopped,
                StopKind::Force => TerminationReason::Aborted,
            };
            emit(&event_tx, AgentEvent::TurnComplete { reason }).await;
            return;
        }

        // Enforce max_turns.
        if let Some(limit) = max_turns
            && turn >= limit
        {
            tracing::debug!(turn, limit, "max_turns reached");
            emit(
                &event_tx,
                AgentEvent::TurnComplete {
                    reason: TerminationReason::MaxTurnsExceeded,
                },
            )
            .await;
            return;
        }

        // Enforce max_budget (strictly greater: spending exactly the budget is allowed).
        if let Some(budget) = max_budget
            && cumulative_cost > budget
        {
            tracing::debug!(cumulative_cost, budget, "budget exceeded");
            emit(
                &event_tx,
                AgentEvent::TurnComplete {
                    reason: TerminationReason::BudgetExceeded,
                },
            )
            .await;
            return;
        }

        turn += 1;

        // --- Simulated model invocation ---
        // In production this would call the LLM backend.  The runner provides
        // the scaffolding; actual model calls are injected by provider crates.
        let model_result = invoke_model(&current_model, &messages);

        match model_result {
            Ok(response) => {
                retry_count = 0;

                // Usage reported for this call only (not cumulative).
                let usage = Usage {
                    input_tokens: response.input_tokens,
                    output_tokens: response.output_tokens,
                    ..Usage::default()
                };
                cumulative_cost += response.estimated_cost;

                emit(&event_tx, AgentEvent::UsageUpdate { usage }).await;

                if let Some(text) = response.text {
                    emit(&event_tx, AgentEvent::TextDelta { content: text }).await;
                }

                // Check if model signalled completion.
                if response.done {
                    emit(
                        &event_tx,
                        AgentEvent::TurnComplete {
                            reason: TerminationReason::Complete,
                        },
                    )
                    .await;
                    return;
                }

                // Append assistant message and continue loop.
                messages.push(serde_json::json!({ "role": "assistant", "content": response.raw }));
            }

            Err(err) => {
                // Offer the fallback only while we are not already running on
                // it. Otherwise a failing fallback would be "switched to" again
                // and again. Each switch resets `retry_count` and un-counts the
                // turn, so the run would never terminate.
                let fallback_candidate = agent
                    .fallback_model_name()
                    .filter(|fb| *fb != current_model.as_str());
                let action =
                    dispatch_model_error(&err, retry_count, max_retries, fallback_candidate);

                match action {
                    RunErrorAction::Retry => {
                        retry_count += 1;
                        tracing::warn!(attempt = retry_count, model = %current_model, "Retrying after model error");
                        turn -= 1; // retries don't count against max_turns
                    }
                    RunErrorAction::SwitchModel(fallback) => {
                        tracing::warn!(
                            from = %current_model,
                            to = %fallback,
                            "Switching to fallback model"
                        );
                        current_model = fallback;
                        retry_count = 0;
                        turn -= 1; // the switch doesn't count against max_turns
                    }
                    RunErrorAction::Continue => {
                        tracing::warn!(%err, "Model error ignored — continuing");
                    }
                    RunErrorAction::Abort(msg) => {
                        emit(&event_tx, AgentEvent::Error { message: msg }).await;
                        return;
                    }
                }
            }
        }
    }
}

// ---------------------------------------------------------------------------
// Error dispatch
// ---------------------------------------------------------------------------

/// Map a failed model call to the action the run loop should take.
///
/// * Non-model errors, panics and non-retryable model errors abort.
/// * A retryable model error is retried while `retry_count < max_retries`.
///   Once at least one retry has been spent, the fallback model (if any) is
///   preferred over another plain retry.
/// * With retries exhausted, switch to the fallback if there is one;
///   otherwise abort with a "Max retries" message.
fn dispatch_model_error(
    err: &AgentError,
    retry_count: u32,
    max_retries: u32,
    fallback_model: Option<&str>,
) -> RunErrorAction {
    let model_err = match err {
        AgentError::Model(inner) => inner,
        AgentError::Panic(msg) => {
            tracing::error!(%msg, "Agent panicked");
            return RunErrorAction::Abort(format!("Agent panicked: {msg}"));
        }
        other => return RunErrorAction::Abort(other.to_string()),
    };

    if !model_err.is_retryable() {
        return RunErrorAction::Abort(err.to_string());
    }

    let retries_left = retry_count < max_retries;
    // The fallback is wanted after the first retry, or once retries are exhausted.
    let wants_fallback = retry_count > 0 || !retries_left;

    match (fallback_model, wants_fallback) {
        (Some(fb), true) => RunErrorAction::SwitchModel(fb.to_string()),
        _ if retries_left => RunErrorAction::Retry,
        _ => RunErrorAction::Abort(format!("Max retries ({max_retries}) exceeded: {err}")),
    }
}

// ---------------------------------------------------------------------------
// Stub model invocation (replaced by provider crates at runtime)
// ---------------------------------------------------------------------------

/// Result of a single model call as consumed by the run loop.
struct ModelResponse {
    /// Text forwarded to the caller as a `TextDelta`, if any.
    text: Option<String>,
    /// Payload appended to the conversation as the assistant message.
    raw: Value,
    /// Input token count reported for this call.
    input_tokens: u64,
    /// Output token count reported for this call.
    output_tokens: u64,
    /// Cost of this call; summed by the run loop to enforce the budget.
    estimated_cost: f64,
    /// `true` when the model signals that the run is finished.
    done: bool,
}

/// Placeholder model invocation.  Real implementations are injected by
/// provider crates (e.g. `synwire-llm-openai`) via the `AgentNode::run`
/// delegation path.
///
/// The stub always succeeds with an empty, zero-cost, finished response, so a
/// run driven by it ends after a single turn.
#[allow(clippy::unnecessary_wraps)] // keeps the fallible signature real providers need
fn invoke_model(model: &str, messages: &[Value]) -> Result<ModelResponse, AgentError> {
    let message_count = messages.len();
    tracing::debug!(%model, message_count, "invoke_model (stub)");

    let finished_empty = ModelResponse {
        done: true,
        text: None,
        raw: Value::Null,
        input_tokens: 0,
        output_tokens: 0,
        estimated_cost: 0.0,
    };
    Ok(finished_empty)
}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Send `event` to the run's consumer, waiting for channel capacity.
///
/// A send only fails once the receiver has been dropped. In that case the
/// event is discarded on purpose, since nobody is left to observe it.
async fn emit(tx: &mpsc::Sender<AgentEvent>, event: AgentEvent) {
    let _: Result<(), _> = tx.send(event).await;
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::agents::agent_node::Agent;

    /// Start a run on `runner` with a fixed prompt and the default config.
    async fn start<O: serde::Serialize + Send + Sync + 'static>(
        runner: &Runner<O>,
    ) -> mpsc::Receiver<AgentEvent> {
        runner
            .run(serde_json::json!("Hello"), RunnerConfig::default())
            .await
            .unwrap()
    }

    /// Drain `rx` until the run task closes it, keeping only the
    /// `TurnComplete` reasons in arrival order.
    async fn completion_reasons(mut rx: mpsc::Receiver<AgentEvent>) -> Vec<TerminationReason> {
        let mut reasons = Vec::new();
        while let Some(event) = rx.recv().await {
            if let AgentEvent::TurnComplete { reason } = event {
                reasons.push(reason);
            }
        }
        reasons
    }

    #[tokio::test]
    async fn test_runner_completes() {
        let agent: Agent = Agent::new("test", "stub-model");
        let runner = Runner::new(agent);

        let reasons = completion_reasons(start(&runner).await).await;

        assert!(!reasons.is_empty(), "expected TurnComplete event");
        assert!(
            reasons.iter().all(|r| *r == TerminationReason::Complete),
            "unexpected termination reasons: {reasons:?}"
        );
    }

    #[tokio::test]
    async fn test_runner_max_turns() {
        // The stub model finishes on its first call, so the only way to hit the
        // turn limit is max_turns(0): the limit check fires before any model call.
        let agent: Agent = Agent::new("test", "stub-model").max_turns(0);
        let runner = Runner::new(agent);

        let reasons = completion_reasons(start(&runner).await).await;

        assert!(
            reasons.contains(&TerminationReason::MaxTurnsExceeded),
            "expected MaxTurnsExceeded"
        );
    }

    #[tokio::test]
    async fn test_runner_graceful_stop() {
        let agent: Agent = Agent::new("test", "stub-model");
        let runner = Runner::new(agent);

        let rx = start(&runner).await;

        // Races with the spawned task: the stop may land before or after the
        // stub's single turn, so both outcomes are accepted (tests the wiring).
        runner.stop_graceful().await;

        let reasons = completion_reasons(rx).await;
        assert!(reasons.iter().any(|r| matches!(
            r,
            TerminationReason::Stopped | TerminationReason::Complete
        )));
    }
}