// aidale_core/runtime/executor.rs
use crate::error::AiError;
use crate::layer::Layer;
use crate::plugin::{Plugin, PluginEngine};
use crate::provider::Provider;
use crate::strategy::{detect_json_strategy, JsonOutputStrategy};
use crate::types::*;
use std::sync::Arc;

type BoxedProvider = Arc<dyn Provider>;

18pub struct RuntimeExecutorBuilder<P> {
34 provider: P,
35 plugins: Vec<Arc<dyn Plugin>>,
36 json_strategy: Option<Box<dyn JsonOutputStrategy>>,
37}
39impl<P: Provider> RuntimeExecutorBuilder<P> {
40 pub fn new(provider: P) -> Self {
42 Self {
43 provider,
44 plugins: Vec::new(),
45 json_strategy: None,
46 }
47 }
48
49 pub fn layer<L>(self, layer: L) -> RuntimeExecutorBuilder<L::LayeredProvider>
54 where
55 L: Layer<P>,
56 {
57 RuntimeExecutorBuilder {
58 provider: layer.layer(self.provider),
59 plugins: self.plugins,
60 json_strategy: self.json_strategy,
61 }
62 }
63
64 pub fn plugin(mut self, plugin: Arc<dyn Plugin>) -> Self {
66 self.plugins.push(plugin);
67 self
68 }
69
70 pub fn json_strategy(mut self, strategy: Box<dyn JsonOutputStrategy>) -> Self {
74 self.json_strategy = Some(strategy);
75 self
76 }
77
78 pub fn finish(self) -> RuntimeExecutor {
80 let provider = Arc::new(self.provider);
81 let provider_id = provider.info().id.clone();
82
83 let json_strategy = self
85 .json_strategy
86 .unwrap_or_else(|| detect_json_strategy(&provider_id));
87
88 RuntimeExecutor {
89 provider,
90 plugin_engine: PluginEngine::new(self.plugins),
91 json_strategy,
92 }
93 }
94}
96pub struct RuntimeExecutor {
102 provider: BoxedProvider,
103 plugin_engine: PluginEngine,
104 json_strategy: Box<dyn JsonOutputStrategy>,
105}
107impl RuntimeExecutor {
108 pub fn builder<P: Provider>(provider: P) -> RuntimeExecutorBuilder<P> {
110 RuntimeExecutorBuilder::new(provider)
111 }
112
113 pub fn info(&self) -> Arc<ProviderInfo> {
115 self.provider.info()
116 }
117
118 pub fn plugin_engine(&self) -> &PluginEngine {
120 &self.plugin_engine
121 }
122
123 pub async fn generate_text(
128 &self,
129 model: impl Into<String>,
130 params: TextParams,
131 ) -> Result<TextResult, AiError> {
132 let model = model.into();
133 let provider_info = self.provider.info();
134
135 let ctx = RequestContext::new(provider_info.id.clone(), model.clone());
137
138 let resolved_model = self.plugin_engine.resolve_model(&model, &ctx).await?;
140
141 let transformed_params = self.plugin_engine.transform_params(params, &ctx).await?;
143
144 self.plugin_engine.on_request_start(&ctx).await?;
146
147 let chat_req = ChatCompletionRequest {
149 model: resolved_model.clone(),
150 messages: transformed_params.messages,
151 temperature: transformed_params.temperature,
152 max_tokens: transformed_params.max_tokens,
153 top_p: transformed_params.top_p,
154 frequency_penalty: transformed_params.frequency_penalty,
155 presence_penalty: transformed_params.presence_penalty,
156 stop: transformed_params.stop,
157 tools: transformed_params.tools,
158 response_format: Some(ResponseFormat::Text),
159 stream: Some(false),
160 extra: transformed_params.extra,
161 };
162
163 let result = self.provider.chat_completion(chat_req).await;
165
166 match result {
167 Ok(response) => {
168 let first_choice = response
170 .choices
171 .first()
172 .ok_or_else(|| AiError::provider("No choices in response"))?;
173
174 let content = first_choice
175 .message
176 .content
177 .iter()
178 .filter_map(|part| match part {
179 ContentPart::Text { text } => Some(text.as_str()),
180 _ => None,
181 })
182 .collect::<Vec<_>>()
183 .join("");
184
185 let mut result = TextResult {
186 content,
187 finish_reason: first_choice.finish_reason.clone(),
188 usage: response.usage,
189 model: response.model,
190 tool_calls: None,
191 };
192
193 result = self.plugin_engine.transform_result(result, &ctx).await?;
195
196 self.plugin_engine.on_request_end(&ctx, &result).await?;
198
199 Ok(result)
200 }
201 Err(err) => {
202 let _ = self.plugin_engine.on_error(&err, &ctx).await;
204 Err(err)
205 }
206 }
207 }
208
209 pub async fn generate_object(
215 &self,
216 model: impl Into<String>,
217 params: ObjectParams,
218 ) -> Result<ObjectResult, AiError> {
219 let model = model.into();
220
221 let mut chat_req = ChatCompletionRequest {
223 model: model.clone(),
224 messages: params.messages,
225 temperature: params.temperature,
226 max_tokens: params.max_tokens,
227 top_p: None,
228 frequency_penalty: None,
229 presence_penalty: None,
230 stop: None,
231 tools: None,
232 response_format: None, stream: Some(false),
234 extra: std::collections::HashMap::new(),
235 };
236
237 self.json_strategy.apply(&mut chat_req, ¶ms.schema)?;
239
240 let response = self.provider.chat_completion(chat_req).await?;
242
243 let first_choice = response
245 .choices
246 .first()
247 .ok_or_else(|| AiError::provider("No choices in response"))?;
248
249 let content = first_choice
250 .message
251 .content
252 .iter()
253 .filter_map(|part| match part {
254 ContentPart::Text { text } => Some(text.as_str()),
255 _ => None,
256 })
257 .collect::<Vec<_>>()
258 .join("");
259
260 let object: serde_json::Value = serde_json::from_str(&content)?;
262
263 Ok(ObjectResult {
264 object,
265 usage: response.usage,
266 model: response.model,
267 })
268 }
269}