1use futures::{Stream, StreamExt};
2use serde::de::DeserializeOwned;
3use std::pin::Pin;
4
5use rig::agent::{Agent, MultiTurnStreamItem};
6use rig::completion::{Message, Prompt, PromptError};
7use rig::providers::{
8 anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
9 mistral, moonshot, ollama, openai, openrouter, perplexity, together, xai,
10};
11use rig::streaming::StreamingPrompt;
12use crate::extra_providers::bigmodel;
14
15#[derive(Clone)]
17pub enum AgentVariant {
18 OpenAI(Agent<openai::completion::CompletionModel>),
19 ResponsesOpenAI(Agent<openai::responses_api::ResponsesCompletionModel>),
20 Ollama(Agent<ollama::CompletionModel>),
21 Bigmodel(Agent<bigmodel::CompletionModel>),
22 OpenRouter(Agent<openrouter::completion::CompletionModel>),
23 Anthropic(Agent<anthropic::completion::CompletionModel>),
24 Cohere(Agent<cohere::completion::CompletionModel>),
25 Gemini(Agent<gemini::completion::CompletionModel>),
26 Huggingface(Agent<huggingface::completion::CompletionModel>),
27 Mistral(Agent<mistral::completion::CompletionModel>),
28 Together(Agent<together::completion::CompletionModel>),
29 XAI(Agent<xai::completion::CompletionModel>),
30 Azure(Agent<azure::CompletionModel>),
31 DeepSeek(Agent<deepseek::CompletionModel>),
32 Galadriel(Agent<galadriel::CompletionModel>),
33 Groq(Agent<groq::CompletionModel>),
34 Hyperbolic(Agent<hyperbolic::CompletionModel>),
35 Mira(Agent<mira::CompletionModel>),
36 Mooshot(Agent<moonshot::CompletionModel>),
37 Perplexity(Agent<perplexity::CompletionModel>),
38}
39
40pub type AnyMultiStream<R> =
42 Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, PromptError>> + Send>>;
43
44fn map_stream<SR, R, E>(
46 stream: Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<SR>, E>> + Send>>,
47) -> AnyMultiStream<R>
48where
49 SR: serde::Serialize + Send + 'static,
50 R: DeserializeOwned + Send + 'static,
51 E: std::error::Error + Send + Sync + 'static,
52{
53 Box::pin(stream.map(move |item| {
54 item.map_err(|e| {
55 let error_msg = format!("stream error: {}", e);
57 PromptError::MaxDepthError {
58 max_depth: 0,
59 chat_history: Box::new(vec![]),
60 prompt: error_msg.into(),
61 }
62 })
63 .and_then(|multi_item| {
64 serde_json::to_value(&multi_item)
67 .map_err(|e| PromptError::MaxDepthError {
68 max_depth: 0,
69 chat_history: Box::new(vec![]),
70 prompt: format!("serialization error: {}", e).into(),
71 })
72 .and_then(|val| {
73 serde_json::from_value(val).map_err(|e| PromptError::MaxDepthError {
74 max_depth: 0,
75 chat_history: Box::new(vec![]),
76 prompt: format!("deserialization error: {}", e).into(),
77 })
78 })
79 })
80 }))
81}
82
83impl AgentVariant {
87 pub async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
88 match self {
89 AgentVariant::OpenAI(a) => a.prompt(prompt).await,
90 AgentVariant::ResponsesOpenAI(a) => a.prompt(prompt).await,
91 AgentVariant::Ollama(a) => a.prompt(prompt).await,
92 AgentVariant::Bigmodel(a) => a.prompt(prompt).await,
93 AgentVariant::OpenRouter(a) => a.prompt(prompt).await,
94 AgentVariant::Anthropic(a) => a.prompt(prompt).await,
95 AgentVariant::Cohere(a) => a.prompt(prompt).await,
96 AgentVariant::Gemini(a) => a.prompt(prompt).await,
97 AgentVariant::Huggingface(a) => a.prompt(prompt).await,
98 AgentVariant::Mistral(a) => a.prompt(prompt).await,
99 AgentVariant::Together(a) => a.prompt(prompt).await,
100 AgentVariant::XAI(a) => a.prompt(prompt).await,
101 AgentVariant::Azure(a) => a.prompt(prompt).await,
102 AgentVariant::DeepSeek(a) => a.prompt(prompt).await,
103 AgentVariant::Galadriel(a) => a.prompt(prompt).await,
104 AgentVariant::Groq(a) => a.prompt(prompt).await,
105 AgentVariant::Hyperbolic(a) => a.prompt(prompt).await,
106 AgentVariant::Mira(a) => a.prompt(prompt).await,
107 AgentVariant::Mooshot(a) => a.prompt(prompt).await,
108 AgentVariant::Perplexity(a) => a.prompt(prompt).await,
109 }
110 }
111}
112
113impl AgentVariant {
117 pub async fn stream_prompt<R>(
118 &self,
119 prompt: impl Into<Message> + Send,
120 ) -> Result<AnyMultiStream<R>, PromptError>
121 where
122 R: DeserializeOwned + Send + 'static,
123 {
124 async fn handle<M, R>(
125 agent: &Agent<M>,
126 prompt: impl Into<Message> + Send,
127 ) -> Result<AnyMultiStream<R>, PromptError>
128 where
129 M: rig::completion::CompletionModel + Send + Sync + 'static,
130 <M as rig::completion::CompletionModel>::StreamingResponse:
131 serde::Serialize + Send + 'static,
132 R: DeserializeOwned + Send + 'static,
133 {
134 let raw_stream = agent.stream_prompt(prompt).await;
136
137 Ok(map_stream(raw_stream))
138 }
139
140 match self {
141 AgentVariant::OpenAI(a) => handle::<_, R>(a, prompt).await,
142 AgentVariant::ResponsesOpenAI(a) => handle::<_, R>(a, prompt).await,
143 AgentVariant::Ollama(a) => handle::<_, R>(a, prompt).await,
144 AgentVariant::Bigmodel(a) => handle::<_, R>(a, prompt).await,
145 AgentVariant::OpenRouter(a) => handle::<_, R>(a, prompt).await,
146 AgentVariant::Anthropic(a) => handle::<_, R>(a, prompt).await,
147 AgentVariant::Cohere(a) => handle::<_, R>(a, prompt).await,
148 AgentVariant::Gemini(a) => handle::<_, R>(a, prompt).await,
149 AgentVariant::Huggingface(a) => handle::<_, R>(a, prompt).await,
150 AgentVariant::Mistral(a) => handle::<_, R>(a, prompt).await,
151 AgentVariant::Together(a) => handle::<_, R>(a, prompt).await,
152 AgentVariant::XAI(a) => handle::<_, R>(a, prompt).await,
153 AgentVariant::Azure(a) => handle::<_, R>(a, prompt).await,
154 AgentVariant::DeepSeek(a) => handle::<_, R>(a, prompt).await,
155 AgentVariant::Galadriel(a) => handle::<_, R>(a, prompt).await,
156 AgentVariant::Groq(a) => handle::<_, R>(a, prompt).await,
157 AgentVariant::Hyperbolic(a) => handle::<_, R>(a, prompt).await,
158 AgentVariant::Mira(a) => handle::<_, R>(a, prompt).await,
159 AgentVariant::Mooshot(a) => handle::<_, R>(a, prompt).await,
160 AgentVariant::Perplexity(a) => handle::<_, R>(a, prompt).await,
161 }
162 }
163}