// ai_workbench_lib/modules/model_runner.rs

use anyhow::{Context, Result};
use aws_sdk_bedrockruntime::{primitives::Blob, Client as BedrockClient};
use serde_json::{json, Value};
use std::{collections::HashMap, sync::Arc};
use tracing::info;

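/// Per-1,000-token pricing for a Bedrock model, as used by
/// `calculate_cost_estimate` (values are hard-coded in `get_model_pricing`).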
#[derive(Debug, Clone)]
pub struct ModelPricing {
    pub input: f64,
    pub output: f64,
}

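/// Token counts reported by (or estimated from) a single model invocation.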
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub total_tokens: u32,
}

impl std::fmt::Display for TokenUsage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} total ({} input, {} output)",
            self.total_tokens, self.input_tokens, self.output_tokens)
    }
}

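/// Thin wrapper around the Bedrock runtime client: builds the provider-specific
/// request body for a model, invokes it, and parses the generated text and
/// token usage out of the response.
///
/// Usage sketch (assumes an already-configured Bedrock client):
///
/// ```ignore
/// let runner = ModelRunner::new(bedrock_client);
/// let (text, usage) = runner
///     .invoke_model("amazon.nova-lite-v1:0", "Summarize this text.", 512)
///     .await?;
/// ```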
pub struct ModelRunner {
    bedrock_client: Arc<BedrockClient>,
}

impl ModelRunner {
    pub fn new(bedrock_client: Arc<BedrockClient>) -> Self {
        Self { bedrock_client }
    }

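    /// Invokes `model_id` with `prompt`, truncating the prompt to 15,000
    /// characters first. Returns the generated text plus the token usage, if
    /// the response reports it.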
    pub async fn invoke_model(&self, model_id: &str, prompt: &str, max_tokens: u32) -> Result<(String, Option<TokenUsage>)> {
        // Truncate on a character boundary: slicing with `&prompt[..max_chars]`
        // uses a byte index and can panic inside a multi-byte UTF-8 character.
        let max_chars = 15_000;
        let truncated_prompt = if prompt.chars().count() > max_chars {
            let truncated: String = prompt.chars().take(max_chars).collect();
            format!("{}...\n\n[Content truncated for analysis]", truncated)
        } else {
            prompt.to_string()
        };

        let request_body = self.build_request_body(model_id, &truncated_prompt, max_tokens);
        info!("Invoking model {} with {} character prompt", model_id, truncated_prompt.len());

        let response = self
            .bedrock_client
            .invoke_model()
            .model_id(model_id)
            .content_type("application/json")
            .accept("application/json")
            .body(Blob::new(request_body.to_string().as_bytes()))
            .send()
            .await
            .map_err(|e| {
                anyhow::anyhow!(
                    "Failed to invoke Bedrock model {}: {}. This could be due to:\n\
                     1. Model not available in your region\n\
                     2. Insufficient permissions\n\
                     3. Model access not enabled\n\
                     4. Invalid model ID",
                    model_id, e
                )
            })?;

        let response_body: Value = serde_json::from_slice(response.body().as_ref())
            .context("Failed to parse Bedrock response")?;

        info!(
            "Bedrock response for model {}: {}",
            model_id,
            serde_json::to_string_pretty(&response_body)
                .unwrap_or_else(|_| "Could not serialize response".to_string())
        );

        let output = self.extract_output(model_id, &response_body)?;
        let tokens_used = self.extract_token_usage(&response_body);

        info!("Model {} invoked successfully, token usage: {:?}", model_id, tokens_used);

        Ok((output, tokens_used))
    }

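    /// Convenience wrapper: combines `file_content` and `analysis_prompt` into
    /// a single structured prompt (see `format_prompt_with_file_content`) and
    /// forwards it to `invoke_model`.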
    pub async fn invoke_model_with_file_content(
        &self,
        model_id: &str,
        file_content: &str,
        analysis_prompt: &str,
        max_tokens: u32,
    ) -> Result<(String, Option<TokenUsage>)> {
        info!("invoke_model_with_file_content called:");
        info!("  - model_id: {}", model_id);
        info!("  - file_content length: {} chars", file_content.len());
        info!("  - analysis_prompt: {}", analysis_prompt);
        info!(
            "  - file_content preview (first 200 chars): {}",
            file_content.chars().take(200).collect::<String>()
        );

        let combined_prompt = self.format_prompt_with_file_content(file_content, analysis_prompt);

        info!("Combined prompt length: {} chars", combined_prompt.len());

        self.invoke_model(model_id, &combined_prompt, max_tokens).await
    }

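    /// Wraps the raw file content and the analysis request in delimited
    /// sections so the model analyzes the provided content rather than giving
    /// generic guidance. Returns an error prompt when the content is empty.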
    fn format_prompt_with_file_content(&self, file_content: &str, analysis_prompt: &str) -> String {
        info!("Formatting prompt with file content length: {} chars", file_content.len());

        if file_content.trim().is_empty() {
            return format!(
                "ERROR: No file content provided. Unable to analyze empty content.\n\n\
                 USER REQUEST: {}\n\n\
                 Please ensure the file content is properly loaded before analysis.",
                analysis_prompt
            );
        }

        format!(
            "You are analyzing the content of a file. The file content is provided below, followed by the analysis request.\n\n\
             === FILE CONTENT START ===\n\
             {}\n\
             === FILE CONTENT END ===\n\n\
             === ANALYSIS REQUEST ===\n\
             {}\n\n\
             === INSTRUCTIONS ===\n\
             Please analyze the file content provided above and respond to the analysis request. \
             Base your response ONLY on the content shown between the FILE CONTENT START/END markers. \
             Do not provide generic guidance - analyze the specific content provided.",
            file_content, analysis_prompt
        )
    }

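    /// Builds the provider-specific JSON request body. Each model family on
    /// Bedrock expects a different schema (Nova and Claude 3 use a messages
    /// array; Mistral, Llama, legacy Claude, and Titan take a raw prompt), so
    /// the body is selected by model ID prefix.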
    fn build_request_body(&self, model_id: &str, prompt: &str, max_tokens: u32) -> Value {
        match model_id {
            id if id.starts_with("amazon.nova") => {
                json!({
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "text": prompt
                                }
                            ]
                        }
                    ],
                    "inferenceConfig": {
                        "max_new_tokens": max_tokens,
                        "temperature": 0.1,
                        "top_p": 0.9
                    }
                })
            }
            id if id.starts_with("anthropic.claude-3") => {
                json!({
                    "anthropic_version": "bedrock-2023-05-31",
                    "max_tokens": max_tokens,
                    "temperature": 0.1,
                    "top_p": 0.9,
                    "messages": [
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ]
                })
            }
            id if id.starts_with("mistral.") => {
                json!({
                    "prompt": prompt,
                    "max_tokens": max_tokens,
                    "temperature": 0.1,
                    "top_p": 0.9
                })
            }
            id if id.starts_with("meta.llama") => {
                json!({
                    "prompt": prompt,
                    "max_gen_len": max_tokens,
                    "temperature": 0.1,
                    "top_p": 0.9
                })
            }
            // Legacy (non-Claude-3) Claude models use the text-completion API.
            id if id.starts_with("anthropic.claude") => {
                json!({
                    "prompt": format!("\n\nHuman: {}\n\nAssistant:", prompt),
                    "max_tokens_to_sample": max_tokens,
                    "temperature": 0.1,
                    "top_p": 0.9
                })
            }
            id if id.starts_with("amazon.titan") => {
                json!({
                    "inputText": prompt,
                    "textGenerationConfig": {
                        "maxTokenCount": max_tokens,
                        "temperature": 0.1,
                        "topP": 0.9
                    }
                })
            }
            // Fallback: generic completion-style body for unrecognized models.
            _ => {
                json!({
                    "prompt": prompt,
                    "max_tokens": max_tokens,
                    "temperature": 0.1
                })
            }
        }
    }

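    /// Static per-1,000-token price table (USD) keyed by Bedrock model ID.
    /// Prices are hard-coded and may drift from current AWS pricing.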
    pub fn get_model_pricing() -> HashMap<&'static str, ModelPricing> {
        let mut pricing = HashMap::new();
        pricing.insert("amazon.nova-micro-v1:0", ModelPricing { input: 0.00035, output: 0.0014 });
        pricing.insert("amazon.nova-lite-v1:0", ModelPricing { input: 0.0006, output: 0.0024 });
        pricing.insert("amazon.nova-pro-v1:0", ModelPricing { input: 0.008, output: 0.032 });
        pricing.insert("amazon.titan-text-lite-v1", ModelPricing { input: 0.0003, output: 0.0004 });
        pricing.insert("amazon.titan-text-express-v1", ModelPricing { input: 0.0008, output: 0.0016 });
        pricing.insert("anthropic.claude-3-haiku-20240307-v1:0", ModelPricing { input: 0.00025, output: 0.00125 });
        pricing.insert("anthropic.claude-3-sonnet-20240229-v1:0", ModelPricing { input: 0.003, output: 0.015 });
        pricing.insert("anthropic.claude-3-opus-20240229-v1:0", ModelPricing { input: 0.015, output: 0.075 });
        pricing.insert("anthropic.claude-3-5-sonnet-20240620-v1:0", ModelPricing { input: 0.003, output: 0.015 });
        pricing.insert("meta.llama3-8b-instruct-v1:0", ModelPricing { input: 0.0003, output: 0.0006 });
        pricing.insert("meta.llama3-70b-instruct-v1:0", ModelPricing { input: 0.00265, output: 0.0035 });
        pricing.insert("mistral.mistral-7b-instruct-v0:2", ModelPricing { input: 0.00015, output: 0.0002 });
        pricing.insert("mistral.mixtral-8x7b-instruct-v0:1", ModelPricing { input: 0.00045, output: 0.0007 });
        pricing.insert("mistral.mistral-large-2402-v1:0", ModelPricing { input: 0.004, output: 0.012 });
        pricing.insert("cohere.command-text-v14", ModelPricing { input: 0.0015, output: 0.002 });
        pricing.insert("cohere.command-light-text-v14", ModelPricing { input: 0.0003, output: 0.0006 });
        pricing.insert("ai21.j2-mid-v1", ModelPricing { input: 0.0125, output: 0.0125 });
        pricing.insert("ai21.j2-ultra-v1", ModelPricing { input: 0.0188, output: 0.0188 });
        pricing
    }

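    /// Estimates invocation cost in USD from token usage and the pricing
    /// table: `input_tokens / 1000 * input_price + output_tokens / 1000 *
    /// output_price`. Returns `None` when usage is missing or the model is
    /// not in the table.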
    pub fn calculate_cost_estimate(token_usage: Option<TokenUsage>, model_id: &str) -> Option<f64> {
        token_usage.and_then(|usage| {
            let pricing_map = ModelRunner::get_model_pricing();
            pricing_map.get(model_id).map(|pricing| {
                (usage.input_tokens as f64 * pricing.input / 1000.0)
                    + (usage.output_tokens as f64 * pricing.output / 1000.0)
            })
        })
    }

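    /// Pulls the generated text out of the provider-specific response shape,
    /// falling back to a set of known fields for unrecognized models.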
    fn extract_output(&self, model_id: &str, response_body: &Value) -> Result<String> {
        let output = match model_id {
            id if id.starts_with("amazon.nova") => {
                response_body["output"]["message"]["content"][0]["text"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            id if id.starts_with("anthropic.claude-3") => {
                response_body["content"][0]["text"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            id if id.starts_with("mistral.") => {
                response_body["outputs"][0]["text"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            id if id.starts_with("meta.llama") => {
                response_body["generation"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            id if id.starts_with("anthropic.claude") => {
                response_body["completion"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            id if id.starts_with("amazon.titan") => {
                response_body["results"][0]["outputText"]
                    .as_str()
                    .unwrap_or("No content generated")
                    .to_string()
            }
            _ => {
                response_body["completion"]
                    .as_str()
                    .or_else(|| response_body["text"].as_str())
                    .or_else(|| response_body["outputs"][0]["text"].as_str())
                    .or_else(|| response_body["content"][0]["text"].as_str())
                    .or_else(|| response_body["generation"].as_str())
                    .or_else(|| response_body["results"][0]["outputText"].as_str())
                    .or_else(|| response_body["output"]["message"]["content"][0]["text"].as_str())
                    .unwrap_or("No content generated")
                    .to_string()
            }
        };

        if output.trim().is_empty() || output == "No content generated" {
            info!(
                "Model response body for debugging: {}",
                serde_json::to_string_pretty(response_body)
                    .unwrap_or_else(|_| "Could not serialize response".to_string())
            );
        }

        Ok(output)
    }

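    /// Reads token counts from the common response locations. When only a
    /// total is reported, splits it into an estimated 70% input / 30% output.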
    fn extract_token_usage(&self, response_body: &Value) -> Option<TokenUsage> {
        if let (Some(input), Some(output)) = (
            response_body["usage"]["inputTokens"].as_u64(),
            response_body["usage"]["outputTokens"].as_u64(),
        ) {
            return Some(TokenUsage {
                input_tokens: input as u32,
                output_tokens: output as u32,
                total_tokens: (input + output) as u32,
            });
        }

        if let Some(total) = response_body["usage"]["totalTokens"]
            .as_u64()
            .or_else(|| response_body["usage"]["total_tokens"].as_u64())
        {
            let input_estimate = (total as f64 * 0.7) as u32;
            let output_estimate = total as u32 - input_estimate;
            return Some(TokenUsage {
                input_tokens: input_estimate,
                output_tokens: output_estimate,
                total_tokens: total as u32,
            });
        }

        if let (Some(input), Some(output)) = (
            response_body["amazon-bedrock-invocationMetrics"]["inputTokenCount"].as_u64(),
            response_body["amazon-bedrock-invocationMetrics"]["outputTokenCount"].as_u64(),
        ) {
            return Some(TokenUsage {
                input_tokens: input as u32,
                output_tokens: output as u32,
                total_tokens: (input + output) as u32,
            });
        }

        None
    }
}