// ai_workbench_lib/modules/model_runner.rs

use anyhow::{Context, Result};
use aws_sdk_bedrockruntime::{primitives::Blob, Client as BedrockClient};
use serde_json::{json, Value};
use std::{collections::HashMap, sync::Arc};
use tracing::info;

/// Per-1,000-token pricing for a Bedrock model (see `calculate_cost_estimate`).
#[derive(Debug, Clone)]
pub struct ModelPricing {
    /// Price per 1,000 input tokens (USD).
    pub input: f64,
    /// Price per 1,000 output tokens (USD).
    pub output: f64,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub total_tokens: u32,
}

impl std::fmt::Display for TokenUsage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} total ({} input, {} output)",
               self.total_tokens, self.input_tokens, self.output_tokens)
    }
}

pub struct ModelRunner {
    bedrock_client: Arc<BedrockClient>,
}

impl ModelRunner {
    pub fn new(bedrock_client: Arc<BedrockClient>) -> Self {
        Self { bedrock_client }
    }

    /// Invoke a Bedrock model with the given prompt, returning the generated text
    /// and, when the response reports it, the token usage
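    ///
    /// Illustrative usage (a sketch, not taken from the original module: it assumes
    /// an already-configured Bedrock runtime client and a model ID that is enabled
    /// in your account; marked `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let runner = ModelRunner::new(std::sync::Arc::new(bedrock_client));
    /// let (text, usage) = runner
    ///     .invoke_model("anthropic.claude-3-haiku-20240307-v1:0", "Summarize this text", 512)
    ///     .await?;
    /// println!("{}", text);
    /// if let Some(usage) = usage {
    ///     println!("Tokens: {}", usage); // e.g. "300 total (250 input, 50 output)"
    /// }
    /// ```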
    pub async fn invoke_model(&self, model_id: &str, prompt: &str, max_tokens: u32) -> Result<(String, Option<TokenUsage>)> {
        // Truncate overly long prompts (rough token estimate: ~4 chars per token)
        let max_chars = 15_000; // Leave room for the response
        let truncated_prompt = if prompt.len() > max_chars {
            // Back up to a char boundary so the slice can't panic on multi-byte UTF-8
            let mut cut = max_chars;
            while !prompt.is_char_boundary(cut) {
                cut -= 1;
            }
            format!("{}...\n\n[Content truncated for analysis]", &prompt[..cut])
        } else {
            prompt.to_string()
        };

        // Build request body based on model type
        let request_body = self.build_request_body(model_id, &truncated_prompt, max_tokens);
        info!("Invoking model {} with {} character prompt", model_id, truncated_prompt.chars().count());

        let response = self
            .bedrock_client
            .invoke_model()
            .model_id(model_id)
            .content_type("application/json")
            .accept("application/json")
            .body(Blob::new(request_body.to_string().as_bytes()))
            .send()
            .await
            .map_err(|e| {
                anyhow::anyhow!(
                    "Failed to invoke Bedrock model {}. This could be due to:\n\
                    1. Model not available in your region\n\
                    2. Insufficient permissions\n\
                    3. Model access not enabled\n\
                    4. Invalid model ID\n\
                    Original error: {}",
                    model_id, e
                )
            })?;

        let response_body: Value = serde_json::from_slice(response.body().as_ref())
            .context("Failed to parse Bedrock response")?;

        // Log the raw response for debugging
        info!("Bedrock response for model {}: {}", model_id,
              serde_json::to_string_pretty(&response_body).unwrap_or_else(|_| "Could not serialize response".to_string()));

        // Extract output and token usage
        let output = self.extract_output(model_id, &response_body)?;
        let tokens_used = self.extract_token_usage(&response_body);

        info!("Model {} invoked successfully, token usage: {:?}", model_id, tokens_used);

        Ok((output, tokens_used))
    }

    /// Invoke a model with file content and an analysis prompt
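    ///
    /// Wraps the file content and the analysis request into one combined prompt via
    /// `format_prompt_with_file_content`, then delegates to `invoke_model`.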
    pub async fn invoke_model_with_file_content(
        &self,
        model_id: &str,
        file_content: &str,
        analysis_prompt: &str,
        max_tokens: u32
    ) -> Result<(String, Option<TokenUsage>)> {
        // Log details about the inputs
        info!("invoke_model_with_file_content called:");
        info!("  - model_id: {}", model_id);
        info!("  - file_content length: {} chars", file_content.chars().count());
        info!("  - analysis_prompt: {}", analysis_prompt);
        info!("  - file_content preview (first 200 chars): {}",
              &file_content.chars().take(200).collect::<String>());

        // Create a comprehensive prompt that includes both the file content and the analysis request
        let combined_prompt = self.format_prompt_with_file_content(file_content, analysis_prompt);

        // Log the combined prompt length for debugging
        info!("Combined prompt length: {} chars", combined_prompt.chars().count());

        // Use the standard invoke_model with the combined prompt
        self.invoke_model(model_id, &combined_prompt, max_tokens).await
    }

    /// Format a prompt that includes file content and analysis instructions
    fn format_prompt_with_file_content(&self, file_content: &str, analysis_prompt: &str) -> String {
        // Add debugging info about content length
        info!("Formatting prompt with file content length: {} chars", file_content.chars().count());

        if file_content.trim().is_empty() {
            return format!(
                "ERROR: No file content provided. Unable to analyze empty content.\n\n\
                USER REQUEST: {}\n\n\
                Please ensure the file content is properly loaded before analysis.",
                analysis_prompt
            );
        }

        format!(
            "You are analyzing the content of a file. The file content is provided below, followed by the analysis request.\n\n\
            === FILE CONTENT START ===\n\
            {}\n\
            === FILE CONTENT END ===\n\n\
            === ANALYSIS REQUEST ===\n\
            {}\n\n\
            === INSTRUCTIONS ===\n\
            Please analyze the file content provided above and respond to the analysis request. \
            Base your response ONLY on the content shown between the FILE CONTENT START/END markers. \
            Do not provide generic guidance - analyze the specific content provided.",
            file_content, analysis_prompt
        )
    }

    /// Build request body based on model type
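    ///
    /// Each model family expects a different JSON payload: Nova and Claude 3 take a
    /// messages-style body, Mistral and Llama take a raw `prompt`, older Claude models
    /// use the `Human:`/`Assistant:` completion format, Titan uses `inputText` with a
    /// `textGenerationConfig`, and anything unrecognized falls back to a generic
    /// `prompt`/`max_tokens` body.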
    fn build_request_body(&self, model_id: &str, prompt: &str, max_tokens: u32) -> Value {
        match model_id {
            // Amazon Nova models (working in eu-west-2)
            id if id.starts_with("amazon.nova") => {
                json!({
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "text": prompt
                                }
                            ]
                        }
                    ],
                    "inferenceConfig": {
                        "max_new_tokens": max_tokens,
                        "temperature": 0.1,
                        "top_p": 0.9
                    }
                })
            }
            // Anthropic Claude 3 models (working in eu-west-2)
            id if id.starts_with("anthropic.claude-3") => {
                json!({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": max_tokens,
                    "temperature": 0.1,
                    "top_p": 0.9,
                    "messages": [
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ]
                })
            }
            // Mistral models (working in eu-west-2)
            id if id.starts_with("mistral.") => {
                json!({
                    "prompt": prompt,
                    "max_tokens": max_tokens,
                    "temperature": 0.1,
                    "top_p": 0.9
                })
            }
            // Meta Llama models (working in eu-west-2)
            id if id.starts_with("meta.llama") => {
                json!({
                    "prompt": prompt,
                    "max_gen_len": max_tokens,
                    "temperature": 0.1,
                    "top_p": 0.9
                })
            }
            // Fallback for older Claude models
            id if id.starts_with("anthropic.claude") => {
                json!({
                    "prompt": format!("\n\nHuman: {}\n\nAssistant:", prompt),
                    "max_tokens_to_sample": max_tokens,
                    "temperature": 0.1,
                    "top_p": 0.9
                })
            }
            // Amazon Titan models (if needed)
            id if id.starts_with("amazon.titan") => {
                json!({
                    "inputText": prompt,
                    "textGenerationConfig": {
                        "maxTokenCount": max_tokens,
                        "temperature": 0.1,
                        "topP": 0.9
                    }
                })
            }
            // Default format for any other models
            _ => {
                json!({
                    "prompt": prompt,
                    "max_tokens": max_tokens,
                    "temperature": 0.1
                })
            }
        }
    }

    /// Pricing table keyed by Bedrock model ID; values are USD per 1,000 tokens
    pub fn get_model_pricing() -> HashMap<&'static str, ModelPricing> {
        let mut pricing = HashMap::new();
        pricing.insert("amazon.nova-micro-v1:0", ModelPricing { input: 0.00035, output: 0.0014 });
        pricing.insert("amazon.nova-lite-v1:0", ModelPricing { input: 0.0006, output: 0.0024 });
        pricing.insert("amazon.nova-pro-v1:0", ModelPricing { input: 0.008, output: 0.032 });
        pricing.insert("amazon.titan-text-lite-v1", ModelPricing { input: 0.0003, output: 0.0004 });
        pricing.insert("amazon.titan-text-express-v1", ModelPricing { input: 0.0008, output: 0.0016 });
        pricing.insert("anthropic.claude-3-haiku-20240307-v1:0", ModelPricing { input: 0.00025, output: 0.00125 });
        pricing.insert("anthropic.claude-3-sonnet-20240229-v1:0", ModelPricing { input: 0.003, output: 0.015 });
        pricing.insert("anthropic.claude-3-opus-20240229-v1:0", ModelPricing { input: 0.015, output: 0.075 });
        pricing.insert("anthropic.claude-3-5-sonnet-20240620-v1:0", ModelPricing { input: 0.003, output: 0.015 });
        pricing.insert("meta.llama3-8b-instruct-v1:0", ModelPricing { input: 0.0003, output: 0.0006 });
        pricing.insert("meta.llama3-70b-instruct-v1:0", ModelPricing { input: 0.00265, output: 0.0035 });
        pricing.insert("mistral.mistral-7b-instruct-v0:2", ModelPricing { input: 0.00015, output: 0.0002 });
        pricing.insert("mistral.mixtral-8x7b-instruct-v0:1", ModelPricing { input: 0.00045, output: 0.0007 });
        pricing.insert("mistral.mistral-large-2402-v1:0", ModelPricing { input: 0.004, output: 0.012 });
        pricing.insert("cohere.command-text-v14", ModelPricing { input: 0.0015, output: 0.002 });
        pricing.insert("cohere.command-light-text-v14", ModelPricing { input: 0.0003, output: 0.0006 });
        pricing.insert("ai21.j2-mid-v1", ModelPricing { input: 0.0125, output: 0.0125 });
        pricing.insert("ai21.j2-ultra-v1", ModelPricing { input: 0.0188, output: 0.0188 });
        pricing
    }

    /// Calculates cost estimate based on token usage and model pricing
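    ///
    /// Pricing values are per 1,000 tokens, so the estimate is
    /// `input_tokens * input / 1000 + output_tokens * output / 1000`. For example,
    /// using the table above, 1,000 input and 500 output tokens on
    /// `anthropic.claude-3-haiku-20240307-v1:0` come to
    /// `1000 * 0.00025 / 1000 + 500 * 0.00125 / 1000 = 0.000875`.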
    pub fn calculate_cost_estimate(token_usage: Option<TokenUsage>, model_id: &str) -> Option<f64> {
        token_usage.and_then(|usage| {
            let pricing_map = ModelRunner::get_model_pricing();
            pricing_map.get(model_id).map(|pricing| {
                (usage.input_tokens as f64 * pricing.input / 1000.0) +
                (usage.output_tokens as f64 * pricing.output / 1000.0)
            })
        })
    }

    /// Extract output from response based on model type
    fn extract_output(&self, model_id: &str, response_body: &Value) -> Result<String> {
        let output = match model_id {
            // Amazon Nova models (working in eu-west-2)
            id if id.starts_with("amazon.nova") => {
                response_body["output"]["message"]["content"][0]["text"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            // Anthropic Claude 3 models (working in eu-west-2)
            id if id.starts_with("anthropic.claude-3") => {
                response_body["content"][0]["text"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            // Mistral models (working in eu-west-2)
            id if id.starts_with("mistral.") => {
                response_body["outputs"][0]["text"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            // Meta Llama models (working in eu-west-2)
            id if id.starts_with("meta.llama") => {
                response_body["generation"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            // Older Claude models
            id if id.starts_with("anthropic.claude") => {
                response_body["completion"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            // Amazon Titan models (if needed)
            id if id.starts_with("amazon.titan") => {
                response_body["results"][0]["outputText"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            // Fallback for other models - try multiple common response formats
            _ => {
                response_body["completion"]
                    .as_str()
                    .or_else(|| response_body["text"].as_str())
                    .or_else(|| response_body["outputs"][0]["text"].as_str())
                    .or_else(|| response_body["content"][0]["text"].as_str())
                    .or_else(|| response_body["generation"].as_str())
                    .or_else(|| response_body["results"][0]["outputText"].as_str())
                    .or_else(|| response_body["output"]["message"]["content"][0]["text"].as_str())
                    .unwrap_or("No content generated")
                    .to_string()
            }
        };

        if output.trim().is_empty() || output == "No content generated" {
            info!("Model response body for debugging: {}", serde_json::to_string_pretty(response_body).unwrap_or_else(|_| "Could not serialize response".to_string()));
        }

        Ok(output)
    }

    /// Extract token usage from response
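    ///
    /// Checks, in order: `usage.inputTokens`/`usage.outputTokens`, then
    /// `usage.totalTokens` (or `usage.total_tokens`) with an estimated 70/30
    /// input/output split, and finally the `amazon-bedrock-invocationMetrics`
    /// token counts; returns `None` if none of these are present.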
    fn extract_token_usage(&self, response_body: &Value) -> Option<TokenUsage> {
        // Try to extract separate input/output tokens first
        if let (Some(input), Some(output)) = (
            response_body["usage"]["inputTokens"].as_u64(),
            response_body["usage"]["outputTokens"].as_u64(),
        ) {
            return Some(TokenUsage {
                input_tokens: input as u32,
                output_tokens: output as u32,
                total_tokens: (input + output) as u32,
            });
        }

        // Fallback to total tokens if available
        if let Some(total) = response_body["usage"]["totalTokens"].as_u64()
            .or_else(|| response_body["usage"]["total_tokens"].as_u64())
        {
            // When only total is available, assume 70% input / 30% output split as rough estimate
            let input_estimate = (total as f64 * 0.7) as u32;
            let output_estimate = total as u32 - input_estimate;
            return Some(TokenUsage {
                input_tokens: input_estimate,
                output_tokens: output_estimate,
                total_tokens: total as u32,
            });
        }

        // Last resort: try Bedrock invocation metrics
        if let (Some(input), Some(output)) = (
            response_body["amazon-bedrock-invocationMetrics"]["inputTokenCount"].as_u64(),
            response_body["amazon-bedrock-invocationMetrics"]["outputTokenCount"].as_u64(),
        ) {
            return Some(TokenUsage {
                input_tokens: input as u32,
                output_tokens: output as u32,
                total_tokens: (input + output) as u32,
            });
        }

        None
    }
}
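
// Illustrative sanity checks (not part of the original module): they exercise the
// pure helpers above (TokenUsage's Display impl and calculate_cost_estimate against
// the built-in pricing table) and make no network calls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_usage_display_formats_counts() {
        let usage = TokenUsage { input_tokens: 250, output_tokens: 50, total_tokens: 300 };
        assert_eq!(usage.to_string(), "300 total (250 input, 50 output)");
    }

    #[test]
    fn cost_estimate_uses_per_thousand_token_pricing() {
        // 1,000 input + 500 output tokens on Claude 3 Haiku:
        // 1000 * 0.00025 / 1000 + 500 * 0.00125 / 1000 = 0.000875
        let usage = TokenUsage { input_tokens: 1000, output_tokens: 500, total_tokens: 1500 };
        let cost = ModelRunner::calculate_cost_estimate(
            Some(usage),
            "anthropic.claude-3-haiku-20240307-v1:0",
        )
        .expect("pricing entry exists for this model ID");
        assert!((cost - 0.000875).abs() < 1e-9);

        // Missing usage yields None, as does a model ID absent from the table.
        assert!(ModelRunner::calculate_cost_estimate(None, "amazon.nova-micro-v1:0").is_none());
        let usage = TokenUsage { input_tokens: 10, output_tokens: 10, total_tokens: 20 };
        assert!(ModelRunner::calculate_cost_estimate(Some(usage), "unknown.model-id").is_none());
    }
}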