// aidale_core/runtime/executor.rs

//! RuntimeExecutor implementation.
//!
//! This module implements the RuntimeExecutor, which provides high-level
//! generate_text() and generate_object() APIs by orchestrating provider
//! chat completion calls with strategy selection.

7use crate::error::AiError;
8use crate::layer::Layer;
9use crate::plugin::{Plugin, PluginEngine};
10use crate::provider::Provider;
11use crate::strategy::{detect_json_strategy, JsonOutputStrategy};
12use crate::types::*;
13use std::sync::Arc;
14
/// Type-erased provider handle: `Send + Sync` shareable across threads,
/// and cheap to clone (refcount bump only).
type BoxedProvider = Arc<dyn Provider>;
17
/// Builder for composing AI providers with layers and plugins.
///
/// This builder allows for flexible composition following OpenDAL's pattern:
/// - Layers wrap the provider (static dispatch during building)
/// - Plugins extend the runtime (stored for execution)
///
/// # Example
///
/// ```ignore
/// let executor = RuntimeExecutor::builder(openai_provider)
///     .layer(LoggingLayer::new())
///     .layer(RetryLayer::new())
///     .plugin(Arc::new(ToolUsePlugin::new()))
///     .finish();
/// ```
pub struct RuntimeExecutorBuilder<P> {
    // Innermost provider; each `layer()` call wraps it in a new concrete type.
    provider: P,
    // Plugins handed to the `PluginEngine` when `finish()` is called.
    plugins: Vec<Arc<dyn Plugin>>,
    // Explicit JSON strategy; `None` means auto-detect from the provider ID
    // in `finish()`.
    json_strategy: Option<Box<dyn JsonOutputStrategy>>,
}
38
39impl<P: Provider> RuntimeExecutorBuilder<P> {
40    /// Create a new builder with a provider
41    pub fn new(provider: P) -> Self {
42        Self {
43            provider,
44            plugins: Vec::new(),
45            json_strategy: None,
46        }
47    }
48
49    /// Add a layer to wrap the provider
50    ///
51    /// This uses static dispatch - each call to `layer()` creates a new
52    /// concrete type by wrapping the previous provider.
53    pub fn layer<L>(self, layer: L) -> RuntimeExecutorBuilder<L::LayeredProvider>
54    where
55        L: Layer<P>,
56    {
57        RuntimeExecutorBuilder {
58            provider: layer.layer(self.provider),
59            plugins: self.plugins,
60            json_strategy: self.json_strategy,
61        }
62    }
63
64    /// Add a plugin to the runtime
65    pub fn plugin(mut self, plugin: Arc<dyn Plugin>) -> Self {
66        self.plugins.push(plugin);
67        self
68    }
69
70    /// Set a custom JSON output strategy
71    ///
72    /// If not set, the strategy will be auto-detected based on the provider ID.
73    pub fn json_strategy(mut self, strategy: Box<dyn JsonOutputStrategy>) -> Self {
74        self.json_strategy = Some(strategy);
75        self
76    }
77
78    /// Finish building and create a RuntimeExecutor
79    pub fn finish(self) -> RuntimeExecutor {
80        let provider = Arc::new(self.provider);
81        let provider_id = provider.info().id.clone();
82
83        // Auto-detect strategy if not provided
84        let json_strategy = self
85            .json_strategy
86            .unwrap_or_else(|| detect_json_strategy(&provider_id));
87
88        RuntimeExecutor {
89            provider,
90            plugin_engine: PluginEngine::new(self.plugins),
91            json_strategy,
92        }
93    }
94}
95
/// Runtime executor with plugin support.
///
/// This is the main entry point for making AI requests. It provides high-level
/// APIs (generate_text, generate_object) that internally use the provider's
/// chat_completion API with appropriate strategy selection.
pub struct RuntimeExecutor {
    // Type-erased provider (already wrapped by any layers added at build time).
    provider: BoxedProvider,
    // Runs plugin hooks and request/result transforms.
    plugin_engine: PluginEngine,
    // Strategy used by `generate_object` to request JSON output.
    json_strategy: Box<dyn JsonOutputStrategy>,
}
106
107impl RuntimeExecutor {
108    /// Create a new builder
109    pub fn builder<P: Provider>(provider: P) -> RuntimeExecutorBuilder<P> {
110        RuntimeExecutorBuilder::new(provider)
111    }
112
113    /// Get provider information
114    pub fn info(&self) -> Arc<ProviderInfo> {
115        self.provider.info()
116    }
117
118    /// Get reference to the plugin engine
119    pub fn plugin_engine(&self) -> &PluginEngine {
120        &self.plugin_engine
121    }
122
123    /// Generate text using chat completion
124    ///
125    /// This is a high-level API that converts the request to a chat completion
126    /// request and extracts the text content from the response.
127    pub async fn generate_text(
128        &self,
129        model: impl Into<String>,
130        params: TextParams,
131    ) -> Result<TextResult, AiError> {
132        let model = model.into();
133        let provider_info = self.provider.info();
134
135        // Create request context
136        let ctx = RequestContext::new(provider_info.id.clone(), model.clone());
137
138        // Resolve model through plugins
139        let resolved_model = self.plugin_engine.resolve_model(&model, &ctx).await?;
140
141        // Transform params through plugins
142        let transformed_params = self.plugin_engine.transform_params(params, &ctx).await?;
143
144        // Fire on_request_start hooks
145        self.plugin_engine.on_request_start(&ctx).await?;
146
147        // Convert to chat completion request
148        let chat_req = ChatCompletionRequest {
149            model: resolved_model.clone(),
150            messages: transformed_params.messages,
151            temperature: transformed_params.temperature,
152            max_tokens: transformed_params.max_tokens,
153            top_p: transformed_params.top_p,
154            frequency_penalty: transformed_params.frequency_penalty,
155            presence_penalty: transformed_params.presence_penalty,
156            stop: transformed_params.stop,
157            tools: transformed_params.tools,
158            response_format: Some(ResponseFormat::Text),
159            stream: Some(false),
160            extra: transformed_params.extra,
161        };
162
163        // Make the actual request
164        let result = self.provider.chat_completion(chat_req).await;
165
166        match result {
167            Ok(response) => {
168                // Convert ChatCompletionResponse to TextResult
169                let first_choice = response
170                    .choices
171                    .first()
172                    .ok_or_else(|| AiError::provider("No choices in response"))?;
173
174                let content = first_choice
175                    .message
176                    .content
177                    .iter()
178                    .filter_map(|part| match part {
179                        ContentPart::Text { text } => Some(text.as_str()),
180                        _ => None,
181                    })
182                    .collect::<Vec<_>>()
183                    .join("");
184
185                let mut result = TextResult {
186                    content,
187                    finish_reason: first_choice.finish_reason.clone(),
188                    usage: response.usage,
189                    model: response.model,
190                    tool_calls: None,
191                };
192
193                // Transform result through plugins
194                result = self.plugin_engine.transform_result(result, &ctx).await?;
195
196                // Fire on_request_end hooks
197                self.plugin_engine.on_request_end(&ctx, &result).await?;
198
199                Ok(result)
200            }
201            Err(err) => {
202                // Fire on_error hooks
203                let _ = self.plugin_engine.on_error(&err, &ctx).await;
204                Err(err)
205            }
206        }
207    }
208
209    /// Generate object using chat completion with JSON output
210    ///
211    /// This is a high-level API that handles provider-specific JSON output strategies.
212    /// It automatically selects the appropriate strategy (JSON Schema or JSON Mode)
213    /// based on the provider capabilities.
214    pub async fn generate_object(
215        &self,
216        model: impl Into<String>,
217        params: ObjectParams,
218    ) -> Result<ObjectResult, AiError> {
219        let model = model.into();
220
221        // Convert to chat completion request
222        let mut chat_req = ChatCompletionRequest {
223            model: model.clone(),
224            messages: params.messages,
225            temperature: params.temperature,
226            max_tokens: params.max_tokens,
227            top_p: None,
228            frequency_penalty: None,
229            presence_penalty: None,
230            stop: None,
231            tools: None,
232            response_format: None, // Will be set by strategy
233            stream: Some(false),
234            extra: std::collections::HashMap::new(),
235        };
236
237        // Apply JSON output strategy
238        self.json_strategy.apply(&mut chat_req, &params.schema)?;
239
240        // Make the actual request
241        let response = self.provider.chat_completion(chat_req).await?;
242
243        // Extract JSON object from response
244        let first_choice = response
245            .choices
246            .first()
247            .ok_or_else(|| AiError::provider("No choices in response"))?;
248
249        let content = first_choice
250            .message
251            .content
252            .iter()
253            .filter_map(|part| match part {
254                ContentPart::Text { text } => Some(text.as_str()),
255                _ => None,
256            })
257            .collect::<Vec<_>>()
258            .join("");
259
260        // Parse JSON content
261        let object: serde_json::Value = serde_json::from_str(&content)?;
262
263        Ok(ObjectResult {
264            object,
265            usage: response.usage,
266            model: response.model,
267        })
268    }
269}