Skip to main content

outrig_cli/
llm.rs

1//! Resolve agent -> model -> provider; build Rig agent.
2
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6
7#[cfg(feature = "local-llm")]
8use futures_util::StreamExt;
9use rig::agent::{HookAction, PromptHook, ToolCallHookAction};
10#[cfg(feature = "local-llm")]
11use rig::agent::{MultiTurnStreamItem, StreamingError};
12use rig::completion::{CompletionModel, Message, Prompt};
13#[cfg(feature = "local-llm")]
14use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
15use thiserror::Error;
16#[cfg(feature = "local-llm")]
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18
19use crate::error::Result;
20use crate::rig_tool::McpToolAdapter;
21use outrig::config::{Config, DEFAULT_TOOL_CALL_MAX, LlmProvider, MistralrsDeviceSpec};
22
23/// Hard max on tool calls per turn. The hook below trips this; rig's own
24/// `max_turns` is set to the same value as a defense in depth, so whichever
25/// fires first surfaces a controllable message.
26pub const MAX_TOOL_CALLS: usize = DEFAULT_TOOL_CALL_MAX as usize;
27
28/// Default byte ceiling applied to each individual MCP tool result before it
29/// is handed to Rig and appended to model-visible chat history.
30pub const DEFAULT_TOOL_RESULT_MAX_BYTES: usize =
31    outrig::config::DEFAULT_TOOL_RESULT_MAX_BYTES as usize;
32
33/// Default per-request HTTP timeout for OpenAi-style providers when
34/// `request-timeout-secs` is unset (see `doc/reference/config.md`). Generous
35/// enough not to truncate long reasoning completions, and above typical proxy
36/// timeouts so a client-side timeout never races a still-in-flight server
37/// request.
38pub const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 600;
39
40pub mod retry;
41
42#[cfg(feature = "local-llm")]
43pub mod mistralrs;
44#[cfg(feature = "local-llm")]
45pub mod registry;
46
47#[cfg(feature = "local-llm")]
48pub use registry::LlmRegistry;
49
/// System preamble applied to any agent whose config leaves `preamble`
/// unset. Intentionally neutral: agents that need specific instructions
/// provide their own preamble.
const DEFAULT_PREAMBLE: &str = concat!(
    "You are a careful assistant ",
    "whose tools run inside a sandboxed container."
);
54
/// Failures that surface while walking `agents -> models -> providers` or
/// constructing the Rig client. Wrapped into [`crate::error::OutrigError`]
/// at the top level via `#[from]`.
#[derive(Debug, Error)]
pub enum LlmResolveError {
    /// The requested agent has no entry in `cfg.agents`. `known` is a
    /// comma-separated list of the defined agent names, or `(none)` when
    /// the config defines no agents.
    #[error(
        "agent {name:?} is not defined; pass --agent <name> or set \
         default-agent in config. Known agents: {known}"
    )]
    UnknownAgent { name: String, known: String },

    /// No model override was supplied, the agent sets no `model`, and the
    /// config has no `default_model` to fall back to.
    #[error("agent {agent:?} omits 'model' and no default-model is set")]
    AgentMissingModel { agent: String },

    /// The selected model name has no entry in `cfg.models`.
    #[error("model {name:?} is not defined under [models.<name>]")]
    UnknownModel { name: String },

    /// The model's `provider` has no entry in `cfg.providers`.
    #[error("provider {name:?} is not defined under [providers.<name>]")]
    UnknownProvider { name: String },

    /// A mistralrs-style provider was selected, but this build lacks the
    /// `local-llm` feature. Raised by `build_agent`.
    #[error(
        "mistralrs provider {name:?} requested but this build of outrig \
         does not include the 'local-llm' feature; rebuild with \
         --features local-llm to enable"
    )]
    MistralrsFeatureDisabled { name: String },

    /// The model config's `device` string could not be parsed into a
    /// `MistralrsDeviceSpec`.
    #[error(
        "mistralrs model {model:?} has invalid device {device:?}; \
         expected one of: cpu, cuda, cuda:N, metal"
    )]
    MistralrsDeviceInvalid { model: String, device: String },

    /// The requested device needs a cargo feature (`cuda` or `metal`) that
    /// is not compiled into this `local-llm` build.
    #[error(
        "mistralrs model {model:?} requested device {device:?} but this \
         build of outrig does not include the '{feature}' feature; rebuild \
         with --features {feature} to enable"
    )]
    MistralrsDeviceUnavailable {
        model: String,
        device: String,
        feature: &'static str,
    },

    /// A device override was supplied for a model whose provider is not
    /// mistralrs-style. The override is rejected rather than silently ignored.
    #[error(
        "model {model:?} uses provider {provider:?}, which is not \
         style=mistralrs; --device only applies to mistralrs models"
    )]
    MistralrsDeviceOverrideUnsupported { model: String, provider: String },

    /// The requested context length exceeds the model's maximum.
    /// NOTE(review): not constructed in this file; presumably raised by the
    /// mistralrs loader. Confirm.
    #[cfg(feature = "local-llm")]
    #[error(
        "mistralrs model {model:?}: requested context-length \
         {requested} exceeds the model's maximum of {max}"
    )]
    MistralrsContextTooLong {
        model: String,
        requested: u32,
        max: usize,
    },

    /// Loading a mistralrs model failed. `build_agent` also uses this variant
    /// when a resolved mistralrs agent unexpectedly has no `model_weights`.
    #[cfg(feature = "local-llm")]
    #[error("mistralrs model {model:?}: failed to load model: {source}")]
    MistralrsLoad {
        model: String,
        #[source]
        source: anyhow::Error,
    },

    /// Building the HTTP (reqwest) client or the rig OpenAi client failed.
    /// The payload is the underlying error's `Display` text.
    #[error("failed to build rig client: {0}")]
    RigClientBuild(String),
}
127
/// Runtime-shaped provider view -- mirrors the config `LlmProvider` enum, but
/// with the env-var-backed `ApiKeyRef` already resolved to a plain `String`
/// for the OpenAi variant. Variants are kept in sync with `LlmProvider`'s.
///
/// `Debug` is implemented by hand rather than derived so that the resolved
/// api-key never appears in `{:?}` output. That covers logs, `dbg!`,
/// assertion-failure messages, and the derived `Debug` of any struct that
/// embeds this enum. The redacted field still shows up as `"<redacted>"`,
/// which keeps the output shape familiar.
#[derive(Clone, PartialEq)]
pub enum ResolvedProvider {
    /// OpenAI-compatible HTTP endpoint.
    OpenAi {
        /// Endpoint base URL.
        base_url: String,
        /// Resolved secret. Never printed by `Debug`.
        api_key: String,
        /// Per-request HTTP timeout in seconds. `None` falls back to
        /// `DEFAULT_REQUEST_TIMEOUT_SECS` when the client is built.
        request_timeout_secs: Option<u64>,
    },
    /// In-process mistralrs backend. Its weight spec lives on
    /// `ResolvedAgent::model_weights`.
    Mistralrs,
}

impl std::fmt::Debug for ResolvedProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::OpenAi {
                base_url,
                api_key: _,
                request_timeout_secs,
            } => f
                .debug_struct("OpenAi")
                .field("base_url", base_url)
                // Deliberately redacted: this is a live credential.
                .field("api_key", &"<redacted>")
                .field("request_timeout_secs", request_timeout_secs)
                .finish(),
            Self::Mistralrs => f.write_str("Mistralrs"),
        }
    }
}
140
/// Weight-source spec for a mistralrs-backed model. Lifted off
/// `[models.<name>]` at resolve time. Only one of `model_id` / `model_path`
/// is meaningful in any given instance; validation enforces that, but
/// `mistralrs::load` is also defensive.
///
/// `resolve_agent_with_overrides` copies every field verbatim from the model
/// config except `device`. That one is either parsed from the config string
/// or taken from a caller-supplied override. In `local-llm` builds it is also
/// checked against the enabled `cuda` / `metal` features.
#[derive(Debug, Clone, PartialEq)]
pub struct MistralrsWeights {
    /// Hugging Face model id. When set, it is also used as the display
    /// identifier.
    pub model_id: Option<String>,
    /// Local weights path. Its file name is the display-identifier fallback
    /// when `model_id` is unset.
    pub model_path: Option<PathBuf>,
    /// Specific weight file name(s) to load.
    /// NOTE(review): presumably GGUF file names inside the repo/directory;
    /// confirm in `mistralrs::load`.
    pub model_file: Option<Vec<String>>,
    /// NOTE(review): presumably the Hugging Face revision (branch/tag/commit)
    /// to fetch. Confirm.
    pub revision: Option<String>,
    /// Requested context length, if any. See
    /// `LlmResolveError::MistralrsContextTooLong`.
    pub context_length: Option<u32>,
    /// Device to run on. Defaults to `Cpu` when the model config omits
    /// `device` and no override is supplied.
    pub device: MistralrsDeviceSpec,
}
154
/// Fully-resolved view of one agent: every knob the agent loop needs to
/// build a Rig client and run a turn.
///
/// For the `OpenAi` provider variant, the api-key is resolved from the env at
/// construction time. The struct lives in the agent loop, not in session
/// metadata, so it should never get serialized.
#[derive(Debug, Clone, PartialEq)]
pub struct ResolvedAgent {
    /// Key of the agent in `cfg.agents`.
    pub agent_name: String,
    /// Key of the selected model in `cfg.models`. The override, the agent's
    /// `model`, and the config's `default_model` are tried in that order.
    pub model_name: String,
    /// For OpenAi providers, the model id sent to the endpoint (the model's
    /// `identifier`, else `model_name`). For mistralrs, a display label only.
    pub model_identifier: String,
    /// Key of the model's provider in `cfg.providers`.
    pub provider_name: String,
    /// Runtime view of the provider, with secrets already resolved.
    pub provider: ResolvedProvider,
    /// `Some` for mistralrs-style models, `None` for openai-style. Carries
    /// the per-model weight spec that used to live on the provider config.
    pub model_weights: Option<MistralrsWeights>,
    /// System preamble. `DEFAULT_PREAMBLE` when the agent sets none.
    pub preamble: String,
    /// Sampling temperature. Only applied to the Rig agent when `Some`.
    pub temperature: Option<f32>,
    /// Completion token limit. Only applied to the Rig agent when `Some`.
    pub max_tokens: Option<u32>,
    /// Per-turn tool-call cap. Resolved from the agent, then the config, then
    /// `DEFAULT_TOOL_CALL_MAX`.
    pub tool_call_max: usize,
    /// Per-tool-result byte ceiling. Resolved from the agent's
    /// `tool_result_max`, then the config, then the crate default.
    /// NOTE(review): enforced outside this file, presumably by the MCP tool
    /// adapter.
    pub tool_result_max_bytes: usize,
    /// The agent's `image` setting, copied verbatim.
    /// NOTE(review): presumably the sandbox container image. Confirm.
    pub image: Option<String>,
}
178
179/// Walk `cfg.agents -> models -> providers` to resolve every knob the agent
180/// loop needs. Bails with a descriptive error if a reference is dangling or
181/// the api-key env var is unset.
182///
183/// Each lookup is re-checked here -- the function does not assume
184/// `cfg.validate()` was called -- so errors carry the resolution context
185/// (which agent, which model) regardless.
186pub fn resolve_agent(cfg: &Config, agent_name: &str) -> Result<ResolvedAgent> {
187    resolve_agent_with_overrides(cfg, agent_name, None, None)
188}
189
190pub fn resolve_agent_with_device_override(
191    cfg: &Config,
192    agent_name: &str,
193    device_override: Option<MistralrsDeviceSpec>,
194) -> Result<ResolvedAgent> {
195    resolve_agent_with_overrides(cfg, agent_name, None, device_override)
196}
197
198pub fn resolve_agent_with_overrides(
199    cfg: &Config,
200    agent_name: &str,
201    model_override: Option<&str>,
202    device_override: Option<MistralrsDeviceSpec>,
203) -> Result<ResolvedAgent> {
204    let agent = cfg.agents.get(agent_name).ok_or_else(|| {
205        let known = if cfg.agents.is_empty() {
206            "(none)".to_string()
207        } else {
208            cfg.agents
209                .keys()
210                .map(String::as_str)
211                .collect::<Vec<_>>()
212                .join(", ")
213        };
214        LlmResolveError::UnknownAgent {
215            name: agent_name.to_string(),
216            known,
217        }
218    })?;
219
220    let model_name = model_override
221        .or(agent.model.as_deref())
222        .or(cfg.default_model.as_deref())
223        .ok_or_else(|| LlmResolveError::AgentMissingModel {
224            agent: agent_name.to_string(),
225        })?;
226
227    let model = cfg
228        .models
229        .get(model_name)
230        .ok_or_else(|| LlmResolveError::UnknownModel {
231            name: model_name.to_string(),
232        })?;
233
234    let provider =
235        cfg.providers
236            .get(&model.provider)
237            .ok_or_else(|| LlmResolveError::UnknownProvider {
238                name: model.provider.clone(),
239            })?;
240
241    let (resolved_provider, model_weights, model_identifier) = match provider {
242        LlmProvider::OpenAi {
243            base_url,
244            api_key,
245            request_timeout_secs,
246        } => {
247            if device_override.is_some() {
248                return Err(LlmResolveError::MistralrsDeviceOverrideUnsupported {
249                    model: model_name.to_string(),
250                    provider: model.provider.clone(),
251                }
252                .into());
253            }
254            let identifier = model
255                .identifier
256                .clone()
257                .unwrap_or_else(|| model_name.to_string());
258            (
259                ResolvedProvider::OpenAi {
260                    base_url: base_url.clone(),
261                    api_key: api_key.resolve()?,
262                    request_timeout_secs: *request_timeout_secs,
263                },
264                None,
265                identifier,
266            )
267        }
268        LlmProvider::Mistralrs => {
269            let device = match device_override {
270                Some(device) => validate_mistralrs_device(model_name, device)?,
271                None => parse_mistralrs_device(model_name, model.device.as_deref())?,
272            };
273            let weights = MistralrsWeights {
274                model_id: model.model_id.clone(),
275                model_path: model.model_path.clone(),
276                model_file: model.model_file.clone(),
277                revision: model.revision.clone(),
278                context_length: model.context_length,
279                device,
280            };
281            // For display: prefer the HF model-id, fall back to the GGUF
282            // basename, then the model name. mistralrs's own `load()`
283            // derives the same kind of identifier internally; this is for
284            // banner / error messaging only.
285            let identifier = weights
286                .model_id
287                .clone()
288                .or_else(|| {
289                    weights
290                        .model_path
291                        .as_deref()
292                        .and_then(|p| p.file_name())
293                        .and_then(|s| s.to_str())
294                        .map(str::to_string)
295                })
296                .unwrap_or_else(|| model_name.to_string());
297            (ResolvedProvider::Mistralrs, Some(weights), identifier)
298        }
299    };
300
301    Ok(ResolvedAgent {
302        agent_name: agent_name.to_string(),
303        model_name: model_name.to_string(),
304        model_identifier,
305        provider_name: model.provider.clone(),
306        provider: resolved_provider,
307        model_weights,
308        preamble: agent
309            .preamble
310            .clone()
311            .unwrap_or_else(|| DEFAULT_PREAMBLE.to_string()),
312        temperature: agent.temperature,
313        max_tokens: agent.max_tokens,
314        tool_call_max: agent
315            .tool_call_max
316            .or(cfg.tool_call_max)
317            .unwrap_or(DEFAULT_TOOL_CALL_MAX) as usize,
318        tool_result_max_bytes: agent
319            .tool_result_max
320            .or(cfg.tool_result_max)
321            .unwrap_or(outrig::config::DEFAULT_TOOL_RESULT_MAX_BYTES)
322            as usize,
323        image: agent.image.clone(),
324    })
325}
326
327fn parse_mistralrs_device(
328    model_name: &str,
329    device: Option<&str>,
330) -> std::result::Result<MistralrsDeviceSpec, LlmResolveError> {
331    let spec = match device {
332        Some(value) => value
333            .parse()
334            .map_err(|_| LlmResolveError::MistralrsDeviceInvalid {
335                model: model_name.to_string(),
336                device: value.to_string(),
337            })?,
338        None => MistralrsDeviceSpec::Cpu,
339    };
340    if !cfg!(feature = "local-llm") {
341        return Ok(spec);
342    }
343
344    validate_mistralrs_device(model_name, spec)
345}
346
347fn validate_mistralrs_device(
348    model_name: &str,
349    spec: MistralrsDeviceSpec,
350) -> std::result::Result<MistralrsDeviceSpec, LlmResolveError> {
351    if !cfg!(feature = "local-llm") {
352        return Ok(spec);
353    }
354
355    match spec {
356        MistralrsDeviceSpec::Cuda(_) if !cfg!(feature = "cuda") => {
357            Err(LlmResolveError::MistralrsDeviceUnavailable {
358                model: model_name.to_string(),
359                device: spec.to_string(),
360                feature: "cuda",
361            })
362        }
363        MistralrsDeviceSpec::Metal if !cfg!(feature = "metal") => {
364            Err(LlmResolveError::MistralrsDeviceUnavailable {
365                model: model_name.to_string(),
366                device: spec.to_string(),
367                feature: "metal",
368            })
369        }
370        _ => Ok(spec),
371    }
372}
373
/// Runtime-dispatched Rig agent. The OpenAi-backed and mistralrs-backed
/// `CompletionModel` impls produce concretely different `Agent<M>` types
/// (Rig's trait carries associated types, so a single concrete `RigAgent`
/// can't carry both). Callers (the agent loop) match on the variant.
///
/// Each variant also carries the per-turn tool-call cap, copied from
/// `ResolvedAgent::tool_call_max` by `build_agent`.
pub enum RigAgent {
    /// OpenAI-compatible HTTP model, wrapped in `retry::RetryingModel`.
    /// Turns run through the non-streaming path.
    OpenAi {
        agent: rig::agent::Agent<retry::RetryingModel<rig::providers::openai::CompletionModel>>,
        /// Per-turn tool-call cap for the hook and rig's turn limit.
        tool_call_max: usize,
    },
    /// In-process mistralrs model. This variant only exists in `local-llm`
    /// builds, and its turns are streamed to stdout.
    #[cfg(feature = "local-llm")]
    Mistralrs {
        agent: rig::agent::Agent<crate::llm::mistralrs::MistralrsModel>,
        /// Per-turn tool-call cap for the hook and rig's turn limit.
        tool_call_max: usize,
    },
}
389
/// Build a Rig `Agent` ready to receive a turn. Preamble, sampling params,
/// and the dynamic-tool list come from `resolved` plus the caller-supplied
/// MCP-backed adapters. The `cache_root` argument is the directory into
/// which the mistralrs HF-download path stages model files; it's ignored
/// for OpenAi providers.
///
/// The function is async because the mistralrs arm has to load (and on
/// first use, download) a multi-gigabyte model. The OpenAi arm does not
/// await anything itself.
///
/// With the `local-llm` feature, mistralrs models are obtained through
/// `registry.get_or_init(model_name, ..)`. NOTE(review): presumably this
/// reuses an already-loaded model for the same name instead of calling
/// `mistralrs::load` again. Confirm against `LlmRegistry`.
///
/// # Errors
///
/// - [`LlmResolveError::RigClientBuild`] if the reqwest or rig client cannot
///   be built (OpenAi arm).
/// - [`LlmResolveError::MistralrsFeatureDisabled`] for a mistralrs provider
///   in a build without `local-llm`.
/// - `LlmResolveError::MistralrsLoad` if a mistralrs agent has no
///   `model_weights`, plus any error from the registry or the model load.
pub async fn build_agent(
    resolved: &ResolvedAgent,
    tools: Vec<McpToolAdapter>,
    cache_root: &Path,
    #[cfg(feature = "local-llm")] registry: &LlmRegistry,
) -> Result<RigAgent> {
    // Only the mistralrs loader reads `cache_root`. Without `local-llm` this
    // silences the unused-variable warning.
    #[cfg(not(feature = "local-llm"))]
    let _ = cache_root;
    match &resolved.provider {
        ResolvedProvider::OpenAi {
            base_url,
            api_key,
            request_timeout_secs,
        } => {
            use rig::client::CompletionClient;
            use rig::providers::openai::CompletionsClient;

            // Per-request HTTP timeout. Falls back to
            // DEFAULT_REQUEST_TIMEOUT_SECS when the provider sets none.
            let timeout = std::time::Duration::from_secs(
                request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS),
            );
            let http = reqwest::Client::builder()
                .timeout(timeout)
                .build()
                .map_err(|e| LlmResolveError::RigClientBuild(e.to_string()))?;

            let client = CompletionsClient::builder()
                .api_key(api_key.clone())
                .base_url(base_url)
                .http_client(http)
                .build()
                .map_err(|e| LlmResolveError::RigClientBuild(e.to_string()))?;
            // Completion calls go through `retry::RetryingModel`; the retry
            // policy lives in that module.
            let model =
                retry::RetryingModel::new(client.completion_model(&resolved.model_identifier));
            Ok(RigAgent::OpenAi {
                agent: finish_agent(model, resolved, tools),
                tool_call_max: resolved.tool_call_max,
            })
        }
        ResolvedProvider::Mistralrs => {
            #[cfg(not(feature = "local-llm"))]
            {
                Err(LlmResolveError::MistralrsFeatureDisabled {
                    name: resolved.provider_name.clone(),
                }
                .into())
            }
            #[cfg(feature = "local-llm")]
            {
                // `resolve_agent_with_overrides` always fills `model_weights`
                // for mistralrs providers. A missing value is a broken
                // invariant, reported as a load error rather than a panic.
                let weights = resolved.model_weights.as_ref().ok_or_else(|| {
                    LlmResolveError::MistralrsLoad {
                        model: resolved.model_name.clone(),
                        source: anyhow::anyhow!(
                            "internal: resolved mistralrs agent has no model_weights"
                        ),
                    }
                })?;
                // Pull each field out as a borrow/copy so the `async move`
                // loader below captures only these values.
                let model_name = resolved.model_name.as_str();
                let model_id = weights.model_id.as_deref();
                let model_path = weights.model_path.as_deref();
                let model_file = weights.model_file.as_deref();
                let revision = weights.revision.as_deref();
                let context_length = weights.context_length;
                let device = weights.device;
                let model = registry
                    .get_or_init(model_name, || async move {
                        crate::llm::mistralrs::load(
                            model_name,
                            model_id,
                            model_path,
                            model_file,
                            revision,
                            context_length,
                            device,
                            cache_root,
                        )
                        .await
                    })
                    .await?;
                Ok(RigAgent::Mistralrs {
                    // NOTE(review): `model` appears to be a shared handle from
                    // the registry (e.g. `Arc`); the inner model is cloned
                    // for this agent.
                    agent: finish_agent((*model).clone(), resolved, tools),
                    tool_call_max: resolved.tool_call_max,
                })
            }
        }
    }
}
485
486impl RigAgent {
487    /// Run one user turn: prompt the model, drive the model->tool->model loop,
488    /// extend `history` with everything emitted (user prompt + tool turns +
489    /// final assistant reply), and return the assistant's text reply.
490    ///
491    /// The per-turn [`OutrigPromptHook`] prints `[outrig] tool call: ...` to
492    /// stderr for every tool invocation and terminates the loop after
493    /// the resolved tool-call max. If the hook terminates the loop, Rig
494    /// returns the partial chat history it had accumulated; outrig splices in
495    /// that new suffix so the user can send a follow-up prompt to continue.
496    pub async fn run_turn(&self, prompt: &str, history: &mut Vec<Message>) -> Result<String> {
497        match self {
498            RigAgent::OpenAi {
499                agent,
500                tool_call_max,
501            } => run_turn_inner(agent, prompt, history, *tool_call_max).await,
502            #[cfg(feature = "local-llm")]
503            RigAgent::Mistralrs {
504                agent,
505                tool_call_max,
506            } => run_turn_streaming_mistralrs(agent, prompt, history, *tool_call_max).await,
507        }
508    }
509}
510
511async fn run_turn_inner<M: CompletionModel + 'static>(
512    agent: &rig::agent::Agent<M>,
513    prompt: &str,
514    history: &mut Vec<Message>,
515    tool_call_max: usize,
516) -> Result<String> {
517    let hook = OutrigPromptHook::new(tool_call_max);
518    let result = agent
519        .prompt(prompt.to_string())
520        .with_history(history.clone())
521        .max_turns(tool_call_max)
522        .with_hook(hook)
523        .extended_details()
524        .await;
525
526    match result {
527        Ok(response) => {
528            let messages = response
529                .messages
530                .expect("rig populates messages on extended_details");
531            history.extend(messages);
532            Ok(response.output)
533        }
534        Err(other) => handle_prompt_error(other, history),
535    }
536}
537
/// Stream a mistralrs turn to the process's stdout.
///
/// This is a thin adapter over `run_turn_streaming_inner`. That function is
/// generic over the writer so tests can capture output in memory instead.
#[cfg(feature = "local-llm")]
async fn run_turn_streaming_mistralrs(
    agent: &rig::agent::Agent<crate::llm::mistralrs::MistralrsModel>,
    prompt: &str,
    history: &mut Vec<Message>,
    tool_call_max: usize,
) -> Result<String> {
    let mut terminal = tokio::io::stdout();
    run_turn_streaming_inner(agent, prompt, history, tool_call_max, &mut terminal).await
}
548
/// Streaming counterpart of `run_turn_inner`. It drives rig's multi-turn
/// stream and echoes assistant text to `stdout` chunk by chunk as it arrives.
///
/// On success, the new suffix of rig's final history is spliced into
/// `history`. An empty string is returned because the reply has already been
/// written to `stdout`. If the last non-empty text chunk did not end in a
/// newline, a trailing `\n` is written so later output starts on a fresh
/// line.
///
/// On a stream error (for example the hook cancelling after the tool-call
/// max), any unterminated output line is closed first. This is best effort,
/// so the `[outrig] ...` notices that `handle_prompt_error` prints to stderr
/// are not glued onto the end of the model's partial text. The error is then
/// handled exactly like the non-streaming path.
///
/// # Errors
///
/// Write/flush failures on `stdout` while streaming normally, and any stream
/// error that `handle_prompt_error` does not convert into a retained partial
/// turn.
#[cfg(feature = "local-llm")]
async fn run_turn_streaming_inner<M, W>(
    agent: &rig::agent::Agent<M>,
    prompt: &str,
    history: &mut Vec<Message>,
    tool_call_max: usize,
    stdout: &mut W,
) -> Result<String>
where
    M: CompletionModel + 'static,
    W: AsyncWrite + Unpin,
{
    let hook = OutrigPromptHook::new(tool_call_max);
    let mut stream = agent
        .stream_prompt(prompt.to_string())
        .with_history(history.clone())
        .multi_turn(tool_call_max)
        .with_hook(hook)
        .await;

    // True while the output cursor is mid-line, i.e. the last non-empty text
    // chunk did not end with '\n'. This replaces buffering the whole reply
    // just to inspect its final character.
    let mut mid_line = false;
    let mut final_history: Option<Vec<Message>> = None;

    while let Some(item) = stream.next().await {
        match item {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
                stdout.write_all(text.text.as_bytes()).await?;
                stdout.flush().await?;
                if !text.text.is_empty() {
                    mid_line = !text.text.ends_with('\n');
                }
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ToolCall {
                ..
            })) => {
                // Make already-streamed text visible before the tool runs.
                // NOTE(review): assumes rig yields this item before invoking
                // the tool-call hook.
                stdout.flush().await?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(response)) => {
                final_history = response.history().map(|messages| messages.to_vec());
            }
            Ok(_) => {}
            Err(err) => {
                if mid_line {
                    // Best effort: a failed write here must not prevent the
                    // partial history from being retained below.
                    let _ = stdout.write_all(b"\n").await;
                    let _ = stdout.flush().await;
                }
                return handle_streaming_error(err, history);
            }
        }
    }

    if let Some(messages) = final_history {
        extend_history_with_new_suffix(history, messages);
    }

    if mid_line {
        stdout.write_all(b"\n").await?;
        stdout.flush().await?;
    }

    Ok(String::new())
}
605
/// Map a streaming-side error onto the equivalent `PromptError`. This lets
/// the streaming and non-streaming paths share `handle_prompt_error`'s
/// soft-stop and history-retention logic.
#[cfg(feature = "local-llm")]
fn handle_streaming_error(err: StreamingError, history: &mut Vec<Message>) -> Result<String> {
    use rig::completion::PromptError;

    let as_prompt_error = match err {
        StreamingError::Prompt(boxed) => *boxed,
        StreamingError::Completion(inner) => PromptError::CompletionError(inner),
        StreamingError::Tool(inner) => PromptError::ToolError(inner),
    };
    handle_prompt_error(as_prompt_error, history)
}
615
616fn handle_prompt_error(
617    err: rig::completion::PromptError,
618    history: &mut Vec<Message>,
619) -> Result<String> {
620    match err {
621        rig::completion::PromptError::PromptCancelled {
622            reason,
623            chat_history,
624        } => {
625            eprintln!("[outrig] {reason}");
626            eprintln!(
627                "[outrig] partial history retained -- send another prompt \
628                 (e.g. \"continue\") to keep going, or \"/reset\" to drop it."
629            );
630            extend_history_with_new_suffix(history, chat_history);
631            Ok("(turn ended; tool-call max reached)".to_string())
632        }
633        rig::completion::PromptError::MaxTurnsError {
634            max_turns,
635            chat_history,
636            ..
637        } => {
638            eprintln!("[outrig] tool-call iteration max ({max_turns}) reached; ending turn");
639            eprintln!(
640                "[outrig] partial history retained -- send another prompt \
641                 (e.g. \"continue\") to keep going, or \"/reset\" to drop it."
642            );
643            extend_history_with_new_suffix(history, *chat_history);
644            Ok("(turn ended; tool-call max reached)".to_string())
645        }
646        other => Err(other.into()),
647    }
648}
649
650fn extend_history_with_new_suffix(history: &mut Vec<Message>, returned: Vec<Message>) {
651    let existing_len = history.len();
652    if returned.len() >= existing_len && returned[..existing_len] == history[..] {
653        history.extend(returned.into_iter().skip(existing_len));
654    } else {
655        history.extend(returned);
656    }
657}
658
/// Per-turn prompt hook. It traces each tool call to stderr and ends the
/// agent loop once more than `max` tool calls have been requested.
///
/// Rig clones the hook per request. The counter and the cap flag sit behind
/// `Arc`, so every clone made during one turn counts against the same budget.
/// A fresh hook (and therefore a fresh budget) is created for each turn.
#[derive(Clone)]
pub struct OutrigPromptHook {
    /// Tool calls requested so far this turn, shared across clones.
    counter: Arc<AtomicUsize>,
    /// Set once a call has been refused for exceeding `max`.
    cap_reached: Arc<AtomicBool>,
    /// Maximum number of tool calls allowed to execute this turn.
    max: usize,
}

impl OutrigPromptHook {
    /// Create a hook with a zeroed counter, the cap flag cleared, and a
    /// budget of `max` tool calls.
    pub fn new(max: usize) -> Self {
        // `Default` for the atomics is 0 / false.
        let counter: Arc<AtomicUsize> = Arc::default();
        let cap_reached: Arc<AtomicBool> = Arc::default();
        Self {
            counter,
            cap_reached,
            max,
        }
    }
}
678
679impl<M: CompletionModel> PromptHook<M> for OutrigPromptHook {
680    async fn on_completion_call(&self, _prompt: &Message, _history: &[Message]) -> HookAction {
681        if self.cap_reached.load(Ordering::SeqCst) {
682            return HookAction::terminate(format!(
683                "tool-call iteration max ({}) reached; ending turn",
684                self.max
685            ));
686        }
687        HookAction::cont()
688    }
689
690    async fn on_tool_call(
691        &self,
692        tool_name: &str,
693        _tool_call_id: Option<String>,
694        _internal_call_id: &str,
695        args: &str,
696    ) -> ToolCallHookAction {
697        let n = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
698        if n > self.max {
699            self.cap_reached.store(true, Ordering::SeqCst);
700            return ToolCallHookAction::skip(format!(
701                "[outrig] tool call not executed: per-turn tool-call max ({}) \
702                 was reached before this call could run. The user may continue \
703                 with a fresh max; repeat the tool call if still needed.",
704                self.max
705            ));
706        }
707        eprintln!("[outrig] tool call: {tool_name}({args})");
708        ToolCallHookAction::cont()
709    }
710}
711
712fn finish_agent<M: rig::completion::CompletionModel + 'static>(
713    model: M,
714    resolved: &ResolvedAgent,
715    tools: Vec<McpToolAdapter>,
716) -> rig::agent::Agent<M> {
717    use rig::agent::AgentBuilder;
718    use rig::tool::ToolDyn;
719
720    let mut builder = AgentBuilder::new(model).preamble(&resolved.preamble);
721    if let Some(temperature) = resolved.temperature {
722        builder = builder.temperature(temperature as f64);
723    }
724    if let Some(max_tokens) = resolved.max_tokens {
725        builder = builder.max_tokens(max_tokens as u64);
726    }
727    let boxed: Vec<Box<dyn ToolDyn>> = tools
728        .into_iter()
729        .map(|t| Box::new(t) as Box<dyn ToolDyn>)
730        .collect();
731    builder.tools(boxed).build()
732}
733
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "local-llm")]
    use rig::completion::{CompletionError, CompletionRequest, CompletionResponse, Usage};
    #[cfg(feature = "local-llm")]
    use rig::streaming::{RawStreamingChoice, StreamingCompletionResponse};

    /// When the returned history starts with the existing history, only the
    /// new tail is appended (no duplication).
    #[test]
    fn cancelled_history_retains_only_new_suffix_when_full_history_returned() {
        let original = vec![Message::user("first"), Message::assistant("done")];
        let mut history = original.clone();
        let mut returned = original;
        returned.push(Message::user("second"));
        returned.push(Message::assistant("partial"));

        extend_history_with_new_suffix(&mut history, returned);

        assert_eq!(
            history,
            vec![
                Message::user("first"),
                Message::assistant("done"),
                Message::user("second"),
                Message::assistant("partial"),
            ],
        );
    }

    /// When the returned history does not start with the existing history,
    /// all of it is appended as-is.
    #[test]
    fn cancelled_history_appends_when_returned_history_is_only_partial() {
        let mut history = vec![Message::user("first")];
        let returned = vec![Message::assistant("partial")];

        extend_history_with_new_suffix(&mut history, returned);

        assert_eq!(
            history,
            vec![Message::user("first"), Message::assistant("partial")],
        );
    }

    /// Test double: every `stream` call replays the same scripted raw
    /// chunks, and `completion` returns a single empty text choice.
    #[cfg(feature = "local-llm")]
    #[derive(Clone)]
    struct ScriptedStreamingModel {
        chunks: Arc<Vec<RawStreamingChoice<()>>>,
    }

    #[cfg(feature = "local-llm")]
    impl ScriptedStreamingModel {
        fn new(chunks: Vec<RawStreamingChoice<()>>) -> Self {
            Self {
                chunks: Arc::new(chunks),
            }
        }
    }

    #[cfg(feature = "local-llm")]
    impl CompletionModel for ScriptedStreamingModel {
        type Response = ();
        type StreamingResponse = ();
        type Client = ();

        // Constructed via `make` only to satisfy the trait; it yields an
        // empty script.
        fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
            Self::new(Vec::new())
        }

        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<CompletionResponse<Self::Response>, CompletionError> {
            Ok(CompletionResponse {
                choice: rig::OneOrMany::one(rig::completion::AssistantContent::text("")),
                usage: Usage::new(),
                raw_response: (),
                message_id: None,
            })
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<
            StreamingCompletionResponse<Self::StreamingResponse>,
            CompletionError,
        > {
            let chunks = self.chunks.clone();
            let stream = async_stream::try_stream! {
                for chunk in chunks.iter().cloned() {
                    yield chunk;
                }
            };
            Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
        }
    }

    /// Streamed text is written to the writer exactly once, with a trailing
    /// newline added. The returned reply is empty because it was already
    /// written, and history gains the user prompt plus the assembled
    /// assistant reply.
    #[cfg(feature = "local-llm")]
    #[tokio::test]
    async fn streaming_turn_writes_chunks_once_and_retains_history() {
        let model = ScriptedStreamingModel::new(vec![
            RawStreamingChoice::Message("hello ".to_string()),
            RawStreamingChoice::Message("world".to_string()),
        ]);
        let agent = rig::agent::AgentBuilder::new(model).build();
        let mut history = Vec::new();
        let mut stdout = Vec::new();

        let reply = run_turn_streaming_inner(&agent, "hi", &mut history, 50, &mut stdout)
            .await
            .expect("streaming turn succeeds");

        assert_eq!(reply, "");
        assert_eq!(
            String::from_utf8(stdout).expect("stdout utf-8"),
            "hello world\n"
        );
        assert_eq!(
            history,
            vec![Message::user("hi"), Message::assistant("hello world")],
        );
    }
}