1use std::path::{Path, PathBuf};
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6
7#[cfg(feature = "local-llm")]
8use futures_util::StreamExt;
9use rig::agent::{HookAction, PromptHook, ToolCallHookAction};
10#[cfg(feature = "local-llm")]
11use rig::agent::{MultiTurnStreamItem, StreamingError};
12use rig::completion::{CompletionModel, Message, Prompt};
13#[cfg(feature = "local-llm")]
14use rig::streaming::{StreamedAssistantContent, StreamingPrompt};
15use thiserror::Error;
16#[cfg(feature = "local-llm")]
17use tokio::io::{AsyncWrite, AsyncWriteExt};
18
19use crate::error::Result;
20use crate::rig_tool::McpToolAdapter;
21use outrig::config::{Config, DEFAULT_TOOL_CALL_MAX, LlmProvider, MistralrsDeviceSpec};
22
23pub const MAX_TOOL_CALLS: usize = DEFAULT_TOOL_CALL_MAX as usize;
27
28pub const DEFAULT_TOOL_RESULT_MAX_BYTES: usize =
31 outrig::config::DEFAULT_TOOL_RESULT_MAX_BYTES as usize;
32
33pub const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 600;
39
40pub mod retry;
41
42#[cfg(feature = "local-llm")]
43pub mod mistralrs;
44#[cfg(feature = "local-llm")]
45pub mod registry;
46
47#[cfg(feature = "local-llm")]
48pub use registry::LlmRegistry;
49
50const DEFAULT_PREAMBLE: &str =
53 "You are a careful assistant whose tools run inside a sandboxed container.";
54
55#[derive(Debug, Error)]
59pub enum LlmResolveError {
60 #[error(
61 "agent {name:?} is not defined; pass --agent <name> or set \
62 default-agent in config. Known agents: {known}"
63 )]
64 UnknownAgent { name: String, known: String },
65
66 #[error("agent {agent:?} omits 'model' and no default-model is set")]
67 AgentMissingModel { agent: String },
68
69 #[error("model {name:?} is not defined under [models.<name>]")]
70 UnknownModel { name: String },
71
72 #[error("provider {name:?} is not defined under [providers.<name>]")]
73 UnknownProvider { name: String },
74
75 #[error(
76 "mistralrs provider {name:?} requested but this build of outrig \
77 does not include the 'local-llm' feature; rebuild with \
78 --features local-llm to enable"
79 )]
80 MistralrsFeatureDisabled { name: String },
81
82 #[error(
83 "mistralrs model {model:?} has invalid device {device:?}; \
84 expected one of: cpu, cuda, cuda:N, metal"
85 )]
86 MistralrsDeviceInvalid { model: String, device: String },
87
88 #[error(
89 "mistralrs model {model:?} requested device {device:?} but this \
90 build of outrig does not include the '{feature}' feature; rebuild \
91 with --features {feature} to enable"
92 )]
93 MistralrsDeviceUnavailable {
94 model: String,
95 device: String,
96 feature: &'static str,
97 },
98
99 #[error(
100 "model {model:?} uses provider {provider:?}, which is not \
101 style=mistralrs; --device only applies to mistralrs models"
102 )]
103 MistralrsDeviceOverrideUnsupported { model: String, provider: String },
104
105 #[cfg(feature = "local-llm")]
106 #[error(
107 "mistralrs model {model:?}: requested context-length \
108 {requested} exceeds the model's maximum of {max}"
109 )]
110 MistralrsContextTooLong {
111 model: String,
112 requested: u32,
113 max: usize,
114 },
115
116 #[cfg(feature = "local-llm")]
117 #[error("mistralrs model {model:?}: failed to load model: {source}")]
118 MistralrsLoad {
119 model: String,
120 #[source]
121 source: anyhow::Error,
122 },
123
124 #[error("failed to build rig client: {0}")]
125 RigClientBuild(String),
126}
127
/// Provider settings after resolution, ready to be consumed by
/// `build_agent`.
#[derive(Debug, Clone, PartialEq)]
pub enum ResolvedProvider {
    /// An OpenAI-style HTTP endpoint.
    OpenAi {
        /// Base URL passed to the completions client builder.
        base_url: String,
        /// API key, already resolved to its final string value.
        api_key: String,
        /// Per-request timeout in seconds; `None` means "use the default".
        request_timeout_secs: Option<u64>,
    },
    /// The mistralrs provider; its model details live in `MistralrsWeights`.
    Mistralrs,
}
140
141#[derive(Debug, Clone, PartialEq)]
146pub struct MistralrsWeights {
147 pub model_id: Option<String>,
148 pub model_path: Option<PathBuf>,
149 pub model_file: Option<Vec<String>>,
150 pub revision: Option<String>,
151 pub context_length: Option<u32>,
152 pub device: MistralrsDeviceSpec,
153}
154
155#[derive(Debug, Clone, PartialEq)]
162pub struct ResolvedAgent {
163 pub agent_name: String,
164 pub model_name: String,
165 pub model_identifier: String,
166 pub provider_name: String,
167 pub provider: ResolvedProvider,
168 pub model_weights: Option<MistralrsWeights>,
171 pub preamble: String,
172 pub temperature: Option<f32>,
173 pub max_tokens: Option<u32>,
174 pub tool_call_max: usize,
175 pub tool_result_max_bytes: usize,
176 pub image: Option<String>,
177}
178
179pub fn resolve_agent(cfg: &Config, agent_name: &str) -> Result<ResolvedAgent> {
187 resolve_agent_with_overrides(cfg, agent_name, None, None)
188}
189
190pub fn resolve_agent_with_device_override(
191 cfg: &Config,
192 agent_name: &str,
193 device_override: Option<MistralrsDeviceSpec>,
194) -> Result<ResolvedAgent> {
195 resolve_agent_with_overrides(cfg, agent_name, None, device_override)
196}
197
198pub fn resolve_agent_with_overrides(
199 cfg: &Config,
200 agent_name: &str,
201 model_override: Option<&str>,
202 device_override: Option<MistralrsDeviceSpec>,
203) -> Result<ResolvedAgent> {
204 let agent = cfg.agents.get(agent_name).ok_or_else(|| {
205 let known = if cfg.agents.is_empty() {
206 "(none)".to_string()
207 } else {
208 cfg.agents
209 .keys()
210 .map(String::as_str)
211 .collect::<Vec<_>>()
212 .join(", ")
213 };
214 LlmResolveError::UnknownAgent {
215 name: agent_name.to_string(),
216 known,
217 }
218 })?;
219
220 let model_name = model_override
221 .or(agent.model.as_deref())
222 .or(cfg.default_model.as_deref())
223 .ok_or_else(|| LlmResolveError::AgentMissingModel {
224 agent: agent_name.to_string(),
225 })?;
226
227 let model = cfg
228 .models
229 .get(model_name)
230 .ok_or_else(|| LlmResolveError::UnknownModel {
231 name: model_name.to_string(),
232 })?;
233
234 let provider =
235 cfg.providers
236 .get(&model.provider)
237 .ok_or_else(|| LlmResolveError::UnknownProvider {
238 name: model.provider.clone(),
239 })?;
240
241 let (resolved_provider, model_weights, model_identifier) = match provider {
242 LlmProvider::OpenAi {
243 base_url,
244 api_key,
245 request_timeout_secs,
246 } => {
247 if device_override.is_some() {
248 return Err(LlmResolveError::MistralrsDeviceOverrideUnsupported {
249 model: model_name.to_string(),
250 provider: model.provider.clone(),
251 }
252 .into());
253 }
254 let identifier = model
255 .identifier
256 .clone()
257 .unwrap_or_else(|| model_name.to_string());
258 (
259 ResolvedProvider::OpenAi {
260 base_url: base_url.clone(),
261 api_key: api_key.resolve()?,
262 request_timeout_secs: *request_timeout_secs,
263 },
264 None,
265 identifier,
266 )
267 }
268 LlmProvider::Mistralrs => {
269 let device = match device_override {
270 Some(device) => validate_mistralrs_device(model_name, device)?,
271 None => parse_mistralrs_device(model_name, model.device.as_deref())?,
272 };
273 let weights = MistralrsWeights {
274 model_id: model.model_id.clone(),
275 model_path: model.model_path.clone(),
276 model_file: model.model_file.clone(),
277 revision: model.revision.clone(),
278 context_length: model.context_length,
279 device,
280 };
281 let identifier = weights
286 .model_id
287 .clone()
288 .or_else(|| {
289 weights
290 .model_path
291 .as_deref()
292 .and_then(|p| p.file_name())
293 .and_then(|s| s.to_str())
294 .map(str::to_string)
295 })
296 .unwrap_or_else(|| model_name.to_string());
297 (ResolvedProvider::Mistralrs, Some(weights), identifier)
298 }
299 };
300
301 Ok(ResolvedAgent {
302 agent_name: agent_name.to_string(),
303 model_name: model_name.to_string(),
304 model_identifier,
305 provider_name: model.provider.clone(),
306 provider: resolved_provider,
307 model_weights,
308 preamble: agent
309 .preamble
310 .clone()
311 .unwrap_or_else(|| DEFAULT_PREAMBLE.to_string()),
312 temperature: agent.temperature,
313 max_tokens: agent.max_tokens,
314 tool_call_max: agent
315 .tool_call_max
316 .or(cfg.tool_call_max)
317 .unwrap_or(DEFAULT_TOOL_CALL_MAX) as usize,
318 tool_result_max_bytes: agent
319 .tool_result_max
320 .or(cfg.tool_result_max)
321 .unwrap_or(outrig::config::DEFAULT_TOOL_RESULT_MAX_BYTES)
322 as usize,
323 image: agent.image.clone(),
324 })
325}
326
327fn parse_mistralrs_device(
328 model_name: &str,
329 device: Option<&str>,
330) -> std::result::Result<MistralrsDeviceSpec, LlmResolveError> {
331 let spec = match device {
332 Some(value) => value
333 .parse()
334 .map_err(|_| LlmResolveError::MistralrsDeviceInvalid {
335 model: model_name.to_string(),
336 device: value.to_string(),
337 })?,
338 None => MistralrsDeviceSpec::Cpu,
339 };
340 if !cfg!(feature = "local-llm") {
341 return Ok(spec);
342 }
343
344 validate_mistralrs_device(model_name, spec)
345}
346
347fn validate_mistralrs_device(
348 model_name: &str,
349 spec: MistralrsDeviceSpec,
350) -> std::result::Result<MistralrsDeviceSpec, LlmResolveError> {
351 if !cfg!(feature = "local-llm") {
352 return Ok(spec);
353 }
354
355 match spec {
356 MistralrsDeviceSpec::Cuda(_) if !cfg!(feature = "cuda") => {
357 Err(LlmResolveError::MistralrsDeviceUnavailable {
358 model: model_name.to_string(),
359 device: spec.to_string(),
360 feature: "cuda",
361 })
362 }
363 MistralrsDeviceSpec::Metal if !cfg!(feature = "metal") => {
364 Err(LlmResolveError::MistralrsDeviceUnavailable {
365 model: model_name.to_string(),
366 device: spec.to_string(),
367 feature: "metal",
368 })
369 }
370 _ => Ok(spec),
371 }
372}
373
374pub enum RigAgent {
379 OpenAi {
380 agent: rig::agent::Agent<retry::RetryingModel<rig::providers::openai::CompletionModel>>,
381 tool_call_max: usize,
382 },
383 #[cfg(feature = "local-llm")]
384 Mistralrs {
385 agent: rig::agent::Agent<crate::llm::mistralrs::MistralrsModel>,
386 tool_call_max: usize,
387 },
388}
389
390pub async fn build_agent(
400 resolved: &ResolvedAgent,
401 tools: Vec<McpToolAdapter>,
402 cache_root: &Path,
403 #[cfg(feature = "local-llm")] registry: &LlmRegistry,
404) -> Result<RigAgent> {
405 #[cfg(not(feature = "local-llm"))]
406 let _ = cache_root;
407 match &resolved.provider {
408 ResolvedProvider::OpenAi {
409 base_url,
410 api_key,
411 request_timeout_secs,
412 } => {
413 use rig::client::CompletionClient;
414 use rig::providers::openai::CompletionsClient;
415
416 let timeout = std::time::Duration::from_secs(
417 request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS),
418 );
419 let http = reqwest::Client::builder()
420 .timeout(timeout)
421 .build()
422 .map_err(|e| LlmResolveError::RigClientBuild(e.to_string()))?;
423
424 let client = CompletionsClient::builder()
425 .api_key(api_key.clone())
426 .base_url(base_url)
427 .http_client(http)
428 .build()
429 .map_err(|e| LlmResolveError::RigClientBuild(e.to_string()))?;
430 let model =
431 retry::RetryingModel::new(client.completion_model(&resolved.model_identifier));
432 Ok(RigAgent::OpenAi {
433 agent: finish_agent(model, resolved, tools),
434 tool_call_max: resolved.tool_call_max,
435 })
436 }
437 ResolvedProvider::Mistralrs => {
438 #[cfg(not(feature = "local-llm"))]
439 {
440 Err(LlmResolveError::MistralrsFeatureDisabled {
441 name: resolved.provider_name.clone(),
442 }
443 .into())
444 }
445 #[cfg(feature = "local-llm")]
446 {
447 let weights = resolved.model_weights.as_ref().ok_or_else(|| {
448 LlmResolveError::MistralrsLoad {
449 model: resolved.model_name.clone(),
450 source: anyhow::anyhow!(
451 "internal: resolved mistralrs agent has no model_weights"
452 ),
453 }
454 })?;
455 let model_name = resolved.model_name.as_str();
456 let model_id = weights.model_id.as_deref();
457 let model_path = weights.model_path.as_deref();
458 let model_file = weights.model_file.as_deref();
459 let revision = weights.revision.as_deref();
460 let context_length = weights.context_length;
461 let device = weights.device;
462 let model = registry
463 .get_or_init(model_name, || async move {
464 crate::llm::mistralrs::load(
465 model_name,
466 model_id,
467 model_path,
468 model_file,
469 revision,
470 context_length,
471 device,
472 cache_root,
473 )
474 .await
475 })
476 .await?;
477 Ok(RigAgent::Mistralrs {
478 agent: finish_agent((*model).clone(), resolved, tools),
479 tool_call_max: resolved.tool_call_max,
480 })
481 }
482 }
483 }
484}
485
486impl RigAgent {
487 pub async fn run_turn(&self, prompt: &str, history: &mut Vec<Message>) -> Result<String> {
497 match self {
498 RigAgent::OpenAi {
499 agent,
500 tool_call_max,
501 } => run_turn_inner(agent, prompt, history, *tool_call_max).await,
502 #[cfg(feature = "local-llm")]
503 RigAgent::Mistralrs {
504 agent,
505 tool_call_max,
506 } => run_turn_streaming_mistralrs(agent, prompt, history, *tool_call_max).await,
507 }
508 }
509}
510
511async fn run_turn_inner<M: CompletionModel + 'static>(
512 agent: &rig::agent::Agent<M>,
513 prompt: &str,
514 history: &mut Vec<Message>,
515 tool_call_max: usize,
516) -> Result<String> {
517 let hook = OutrigPromptHook::new(tool_call_max);
518 let result = agent
519 .prompt(prompt.to_string())
520 .with_history(history.clone())
521 .max_turns(tool_call_max)
522 .with_hook(hook)
523 .extended_details()
524 .await;
525
526 match result {
527 Ok(response) => {
528 let messages = response
529 .messages
530 .expect("rig populates messages on extended_details");
531 history.extend(messages);
532 Ok(response.output)
533 }
534 Err(other) => handle_prompt_error(other, history),
535 }
536}
537
/// Runs a streaming turn against a mistralrs agent, writing the reply to the
/// process's stdout as it arrives.
#[cfg(feature = "local-llm")]
async fn run_turn_streaming_mistralrs(
    agent: &rig::agent::Agent<crate::llm::mistralrs::MistralrsModel>,
    prompt: &str,
    history: &mut Vec<Message>,
    tool_call_max: usize,
) -> Result<String> {
    let mut terminal = tokio::io::stdout();
    let sink = &mut terminal;
    run_turn_streaming_inner(agent, prompt, history, tool_call_max, sink).await
}
548
/// Runs one streaming turn, echoing assistant text to `stdout` as it
/// arrives.
///
/// `agent` is prompted with a copy of `history`, limited to `tool_call_max`
/// turns, with an [`OutrigPromptHook`] enforcing the same limit. Text chunks
/// are written and flushed immediately; a tool call only flushes. When the
/// stream delivers a final response that carries a history, that history is
/// merged into `history`.
///
/// Returns an empty string on success because the reply has already been
/// written to `stdout`. Streamed output is always left newline-terminated,
/// including when the stream ends with an error, so diagnostics and later
/// output never get appended to a partial reply line.
///
/// # Errors
/// Propagates write/flush failures on `stdout`; stream errors are handled by
/// [`handle_streaming_error`].
#[cfg(feature = "local-llm")]
async fn run_turn_streaming_inner<M, W>(
    agent: &rig::agent::Agent<M>,
    prompt: &str,
    history: &mut Vec<Message>,
    tool_call_max: usize,
    stdout: &mut W,
) -> Result<String>
where
    M: CompletionModel + 'static,
    W: AsyncWrite + Unpin,
{
    let hook = OutrigPromptHook::new(tool_call_max);
    let mut stream = agent
        .stream_prompt(prompt.to_string())
        .with_history(history.clone())
        .multi_turn(tool_call_max)
        .with_hook(hook)
        .await;

    let mut streamed_reply = String::new();
    let mut final_history: Option<Vec<Message>> = None;

    while let Some(item) = stream.next().await {
        match item {
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
                stdout.write_all(text.text.as_bytes()).await?;
                stdout.flush().await?;
                streamed_reply.push_str(&text.text);
            }
            Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ToolCall {
                ..
            })) => {
                stdout.flush().await?;
            }
            Ok(MultiTurnStreamItem::FinalResponse(response)) => {
                final_history = response.history().map(|messages| messages.to_vec());
            }
            Ok(_) => {}
            Err(err) => {
                // Best effort: close the partial line before the error is
                // reported. A failure to write the newline must not replace
                // the stream error, which is the more useful one.
                let _ = finish_streamed_line(stdout, &streamed_reply).await;
                return handle_streaming_error(err, history);
            }
        }
    }

    if let Some(messages) = final_history {
        extend_history_with_new_suffix(history, messages);
    }

    finish_streamed_line(stdout, &streamed_reply).await?;

    Ok(String::new())
}

/// Writes a trailing newline (and flushes) when `streamed` is non-empty and
/// does not already end with one.
#[cfg(feature = "local-llm")]
async fn finish_streamed_line<W>(stdout: &mut W, streamed: &str) -> std::io::Result<()>
where
    W: AsyncWrite + Unpin,
{
    if !streamed.is_empty() && !streamed.ends_with('\n') {
        stdout.write_all(b"\n").await?;
        stdout.flush().await?;
    }
    Ok(())
}
605
/// Converts a streaming error into the equivalent prompt error and hands it
/// to [`handle_prompt_error`], so both turn runners share one recovery path.
#[cfg(feature = "local-llm")]
fn handle_streaming_error(err: StreamingError, history: &mut Vec<Message>) -> Result<String> {
    use rig::completion::PromptError;

    handle_prompt_error(
        match err {
            StreamingError::Completion(inner) => PromptError::CompletionError(inner),
            // Already a (boxed) prompt error; just unbox it.
            StreamingError::Prompt(boxed) => *boxed,
            StreamingError::Tool(inner) => PromptError::ToolError(inner),
        },
        history,
    )
}
615
616fn handle_prompt_error(
617 err: rig::completion::PromptError,
618 history: &mut Vec<Message>,
619) -> Result<String> {
620 match err {
621 rig::completion::PromptError::PromptCancelled {
622 reason,
623 chat_history,
624 } => {
625 eprintln!("[outrig] {reason}");
626 eprintln!(
627 "[outrig] partial history retained -- send another prompt \
628 (e.g. \"continue\") to keep going, or \"/reset\" to drop it."
629 );
630 extend_history_with_new_suffix(history, chat_history);
631 Ok("(turn ended; tool-call max reached)".to_string())
632 }
633 rig::completion::PromptError::MaxTurnsError {
634 max_turns,
635 chat_history,
636 ..
637 } => {
638 eprintln!("[outrig] tool-call iteration max ({max_turns}) reached; ending turn");
639 eprintln!(
640 "[outrig] partial history retained -- send another prompt \
641 (e.g. \"continue\") to keep going, or \"/reset\" to drop it."
642 );
643 extend_history_with_new_suffix(history, *chat_history);
644 Ok("(turn ended; tool-call max reached)".to_string())
645 }
646 other => Err(other.into()),
647 }
648}
649
650fn extend_history_with_new_suffix(history: &mut Vec<Message>, returned: Vec<Message>) {
651 let existing_len = history.len();
652 if returned.len() >= existing_len && returned[..existing_len] == history[..] {
653 history.extend(returned.into_iter().skip(existing_len));
654 } else {
655 history.extend(returned);
656 }
657}
658
/// Prompt hook that counts tool calls within a turn and stops the turn once
/// more than `max` have been requested.
///
/// The counter and flag sit behind `Arc`s, so clones of a hook share them.
#[derive(Clone)]
pub struct OutrigPromptHook {
    /// Number of tool calls seen so far.
    counter: Arc<AtomicUsize>,
    /// Set once a tool call has been refused for exceeding `max`.
    cap_reached: Arc<AtomicBool>,
    /// Maximum number of tool calls allowed to run.
    max: usize,
}

impl OutrigPromptHook {
    /// Creates a hook allowing up to `max` tool calls, with the counter at
    /// zero and the cap flag cleared.
    pub fn new(max: usize) -> Self {
        let counter = Arc::new(AtomicUsize::default());
        let cap_reached = Arc::new(AtomicBool::default());
        Self {
            counter,
            cap_reached,
            max,
        }
    }
}
678
679impl<M: CompletionModel> PromptHook<M> for OutrigPromptHook {
680 async fn on_completion_call(&self, _prompt: &Message, _history: &[Message]) -> HookAction {
681 if self.cap_reached.load(Ordering::SeqCst) {
682 return HookAction::terminate(format!(
683 "tool-call iteration max ({}) reached; ending turn",
684 self.max
685 ));
686 }
687 HookAction::cont()
688 }
689
690 async fn on_tool_call(
691 &self,
692 tool_name: &str,
693 _tool_call_id: Option<String>,
694 _internal_call_id: &str,
695 args: &str,
696 ) -> ToolCallHookAction {
697 let n = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
698 if n > self.max {
699 self.cap_reached.store(true, Ordering::SeqCst);
700 return ToolCallHookAction::skip(format!(
701 "[outrig] tool call not executed: per-turn tool-call max ({}) \
702 was reached before this call could run. The user may continue \
703 with a fresh max; repeat the tool call if still needed.",
704 self.max
705 ));
706 }
707 eprintln!("[outrig] tool call: {tool_name}({args})");
708 ToolCallHookAction::cont()
709 }
710}
711
712fn finish_agent<M: rig::completion::CompletionModel + 'static>(
713 model: M,
714 resolved: &ResolvedAgent,
715 tools: Vec<McpToolAdapter>,
716) -> rig::agent::Agent<M> {
717 use rig::agent::AgentBuilder;
718 use rig::tool::ToolDyn;
719
720 let mut builder = AgentBuilder::new(model).preamble(&resolved.preamble);
721 if let Some(temperature) = resolved.temperature {
722 builder = builder.temperature(temperature as f64);
723 }
724 if let Some(max_tokens) = resolved.max_tokens {
725 builder = builder.max_tokens(max_tokens as u64);
726 }
727 let boxed: Vec<Box<dyn ToolDyn>> = tools
728 .into_iter()
729 .map(|t| Box::new(t) as Box<dyn ToolDyn>)
730 .collect();
731 builder.tools(boxed).build()
732}
733
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "local-llm")]
    use rig::completion::{CompletionError, CompletionRequest, CompletionResponse, Usage};
    #[cfg(feature = "local-llm")]
    use rig::streaming::{RawStreamingChoice, StreamingCompletionResponse};

    /// A returned history that repeats the existing one contributes only the
    /// messages after the shared prefix.
    #[test]
    fn cancelled_history_retains_only_new_suffix_when_full_history_returned() {
        let prior = vec![Message::user("first"), Message::assistant("done")];
        let mut history = prior.clone();
        let returned: Vec<Message> = prior
            .into_iter()
            .chain([Message::user("second"), Message::assistant("partial")])
            .collect();

        extend_history_with_new_suffix(&mut history, returned);

        let expected = vec![
            Message::user("first"),
            Message::assistant("done"),
            Message::user("second"),
            Message::assistant("partial"),
        ];
        assert_eq!(history, expected);
    }

    /// A returned history that does not start with the existing one is
    /// appended in full.
    #[test]
    fn cancelled_history_appends_when_returned_history_is_only_partial() {
        let mut history = vec![Message::user("first")];

        extend_history_with_new_suffix(&mut history, vec![Message::assistant("partial")]);

        let expected = vec![Message::user("first"), Message::assistant("partial")];
        assert_eq!(history, expected);
    }

    /// Completion model whose `stream` replays a fixed list of chunks.
    #[cfg(feature = "local-llm")]
    #[derive(Clone)]
    struct ScriptedStreamingModel {
        chunks: Arc<Vec<RawStreamingChoice<()>>>,
    }

    #[cfg(feature = "local-llm")]
    impl ScriptedStreamingModel {
        fn new(chunks: Vec<RawStreamingChoice<()>>) -> Self {
            let chunks = Arc::new(chunks);
            Self { chunks }
        }
    }

    #[cfg(feature = "local-llm")]
    impl CompletionModel for ScriptedStreamingModel {
        type Response = ();
        type StreamingResponse = ();
        type Client = ();

        fn make(_client: &Self::Client, _model: impl Into<String>) -> Self {
            Self::new(Vec::new())
        }

        // The non-streaming path is not exercised here; it answers with an
        // empty assistant text.
        async fn completion(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<CompletionResponse<Self::Response>, CompletionError> {
            let empty_reply = rig::OneOrMany::one(rig::completion::AssistantContent::text(""));
            Ok(CompletionResponse {
                choice: empty_reply,
                usage: Usage::new(),
                raw_response: (),
                message_id: None,
            })
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> std::result::Result<
            StreamingCompletionResponse<Self::StreamingResponse>,
            CompletionError,
        > {
            let script = self.chunks.clone();
            let replay = async_stream::try_stream! {
                for chunk in script.iter().cloned() {
                    yield chunk;
                }
            };
            Ok(StreamingCompletionResponse::stream(Box::pin(replay)))
        }
    }

    /// Streaming a two-chunk reply writes each chunk once, terminates the
    /// line, returns an empty string, and records the exchange in history.
    #[cfg(feature = "local-llm")]
    #[tokio::test]
    async fn streaming_turn_writes_chunks_once_and_retains_history() {
        let script = vec![
            RawStreamingChoice::Message("hello ".to_string()),
            RawStreamingChoice::Message("world".to_string()),
        ];
        let agent = rig::agent::AgentBuilder::new(ScriptedStreamingModel::new(script)).build();
        let mut history = Vec::new();
        let mut captured = Vec::new();

        let reply = run_turn_streaming_inner(&agent, "hi", &mut history, 50, &mut captured)
            .await
            .expect("streaming turn succeeds");

        assert_eq!(reply, "");

        let written = String::from_utf8(captured).expect("stdout utf-8");
        assert_eq!(written, "hello world\n");

        let expected_history = vec![Message::user("hi"), Message::assistant("hello world")];
        assert_eq!(history, expected_history);
    }
}