1use futures::{Stream, StreamExt};
2use serde::de::DeserializeOwned;
3use std::pin::Pin;
4
5use rig::agent::{Agent, MultiTurnStreamItem};
6use rig::completion::{CompletionError, Message, Prompt, PromptError};
7use rig::providers::{
8 anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
9 mistral, moonshot, ollama, openai, openrouter, perplexity, together, xai,
10};
11use rig::streaming::StreamingPrompt;
12use crate::extra_providers::bigmodel;
14
15#[derive(Clone)]
17pub enum AgentVariant {
18 OpenAI(Agent<openai::completion::CompletionModel>),
19 ResponsesOpenAI(Agent<openai::responses_api::ResponsesCompletionModel>),
20 Ollama(Agent<ollama::CompletionModel>),
21 Bigmodel(Agent<bigmodel::CompletionModel>),
22 OpenRouter(Agent<openrouter::completion::CompletionModel>),
23 Anthropic(Agent<anthropic::completion::CompletionModel>),
24 Cohere(Agent<cohere::completion::CompletionModel>),
25 Gemini(Agent<gemini::completion::CompletionModel>),
26 Huggingface(Agent<huggingface::completion::CompletionModel>),
27 Mistral(Agent<mistral::completion::CompletionModel>),
28 Together(Agent<together::completion::CompletionModel>),
29 XAI(Agent<xai::completion::CompletionModel>),
30 Azure(Agent<azure::CompletionModel>),
31 DeepSeek(Agent<deepseek::CompletionModel>),
32 Galadriel(Agent<galadriel::CompletionModel>),
33 Groq(Agent<groq::CompletionModel>),
34 Hyperbolic(Agent<hyperbolic::CompletionModel>),
35 Mira(Agent<mira::CompletionModel>),
36 Mooshot(Agent<moonshot::CompletionModel>),
37 Perplexity(Agent<perplexity::CompletionModel>),
38}
39
40pub type AnyMultiStream<R> =
42 Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, PromptError>> + Send>>;
43
44fn map_stream<SR, R, E>(
46 stream: Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<SR>, E>> + Send>>,
47) -> AnyMultiStream<R>
48where
49 SR: serde::Serialize + Send + 'static,
50 R: DeserializeOwned + Send + 'static,
51 E: std::error::Error + Send + Sync + 'static,
52{
53 Box::pin(stream.map(move |item| {
54 item.map_err(|e| {
55 PromptError::CompletionError(CompletionError::ProviderError(e.to_string()))
57 })
58 .and_then(|multi_item| {
59 serde_json::to_value(&multi_item)
62 .map_err(|e| {
63 PromptError::CompletionError(CompletionError::ProviderError(format!(
64 "serialization error: {}",
65 e
66 )))
67 })
68 .and_then(|val| {
69 serde_json::from_value(val).map_err(|e| {
70 PromptError::CompletionError(CompletionError::ProviderError(format!(
71 "deserialization error: {}",
72 e
73 )))
74 })
75 })
76 })
77 }))
78}
79
80impl AgentVariant {
84 pub async fn prompt(&self, prompt: impl Into<Message> + Send) -> Result<String, PromptError> {
85 match self {
86 AgentVariant::OpenAI(a) => a.prompt(prompt).await,
87 AgentVariant::ResponsesOpenAI(a) => a.prompt(prompt).await,
88 AgentVariant::Ollama(a) => a.prompt(prompt).await,
89 AgentVariant::Bigmodel(a) => a.prompt(prompt).await,
90 AgentVariant::OpenRouter(a) => a.prompt(prompt).await,
91 AgentVariant::Anthropic(a) => a.prompt(prompt).await,
92 AgentVariant::Cohere(a) => a.prompt(prompt).await,
93 AgentVariant::Gemini(a) => a.prompt(prompt).await,
94 AgentVariant::Huggingface(a) => a.prompt(prompt).await,
95 AgentVariant::Mistral(a) => a.prompt(prompt).await,
96 AgentVariant::Together(a) => a.prompt(prompt).await,
97 AgentVariant::XAI(a) => a.prompt(prompt).await,
98 AgentVariant::Azure(a) => a.prompt(prompt).await,
99 AgentVariant::DeepSeek(a) => a.prompt(prompt).await,
100 AgentVariant::Galadriel(a) => a.prompt(prompt).await,
101 AgentVariant::Groq(a) => a.prompt(prompt).await,
102 AgentVariant::Hyperbolic(a) => a.prompt(prompt).await,
103 AgentVariant::Mira(a) => a.prompt(prompt).await,
104 AgentVariant::Mooshot(a) => a.prompt(prompt).await,
105 AgentVariant::Perplexity(a) => a.prompt(prompt).await,
106 }
107 }
108}
109
110impl AgentVariant {
114 pub async fn stream_prompt<R>(
115 &self,
116 prompt: impl Into<Message> + Send,
117 ) -> Result<AnyMultiStream<R>, PromptError>
118 where
119 R: DeserializeOwned + Send + 'static,
120 {
121 async fn handle<M, R>(
122 agent: &Agent<M>,
123 prompt: impl Into<Message> + Send,
124 ) -> Result<AnyMultiStream<R>, PromptError>
125 where
126 M: rig::completion::CompletionModel + Send + Sync + 'static,
127 <M as rig::completion::CompletionModel>::StreamingResponse:
128 serde::Serialize + Send + 'static,
129 R: DeserializeOwned + Send + 'static,
130 {
131 let raw_stream = agent.stream_prompt(prompt).await;
133
134 Ok(map_stream(raw_stream))
135 }
136
137 match self {
138 AgentVariant::OpenAI(a) => handle::<_, R>(a, prompt).await,
139 AgentVariant::ResponsesOpenAI(a) => handle::<_, R>(a, prompt).await,
140 AgentVariant::Ollama(a) => handle::<_, R>(a, prompt).await,
141 AgentVariant::Bigmodel(a) => handle::<_, R>(a, prompt).await,
142 AgentVariant::OpenRouter(a) => handle::<_, R>(a, prompt).await,
143 AgentVariant::Anthropic(a) => handle::<_, R>(a, prompt).await,
144 AgentVariant::Cohere(a) => handle::<_, R>(a, prompt).await,
145 AgentVariant::Gemini(a) => handle::<_, R>(a, prompt).await,
146 AgentVariant::Huggingface(a) => handle::<_, R>(a, prompt).await,
147 AgentVariant::Mistral(a) => handle::<_, R>(a, prompt).await,
148 AgentVariant::Together(a) => handle::<_, R>(a, prompt).await,
149 AgentVariant::XAI(a) => handle::<_, R>(a, prompt).await,
150 AgentVariant::Azure(a) => handle::<_, R>(a, prompt).await,
151 AgentVariant::DeepSeek(a) => handle::<_, R>(a, prompt).await,
152 AgentVariant::Galadriel(a) => handle::<_, R>(a, prompt).await,
153 AgentVariant::Groq(a) => handle::<_, R>(a, prompt).await,
154 AgentVariant::Hyperbolic(a) => handle::<_, R>(a, prompt).await,
155 AgentVariant::Mira(a) => handle::<_, R>(a, prompt).await,
156 AgentVariant::Mooshot(a) => handle::<_, R>(a, prompt).await,
157 AgentVariant::Perplexity(a) => handle::<_, R>(a, prompt).await,
158 }
159 }
160}