// aster/providers/sagemaker_tgi.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use aws_config;
7use aws_sdk_bedrockruntime::config::ProvideCredentials;
8use aws_sdk_sagemakerruntime::Client as SageMakerClient;
9use rmcp::model::Tool;
10use serde_json::{json, Value};
11
12use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
13use super::errors::ProviderError;
14use super::retry::ProviderRetry;
15use super::utils::RequestLog;
16use crate::conversation::message::{Message, MessageContent};
17
18use crate::model::ModelConfig;
19use chrono::Utc;
20use rmcp::model::Role;
21
/// AWS documentation page for SageMaker real-time endpoints; passed to the
/// provider metadata as its documentation link.
pub const SAGEMAKER_TGI_DOC_LINK: &str =
    "https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html";

/// Model name advertised by the provider metadata, both as the default model
/// and as the only entry in its known-model list.
pub const SAGEMAKER_TGI_DEFAULT_MODEL: &str = "sagemaker-tgi-endpoint";
26
27#[derive(Debug, serde::Serialize)]
28pub struct SageMakerTgiProvider {
29    #[serde(skip)]
30    sagemaker_client: SageMakerClient,
31    endpoint_name: String,
32    model: ModelConfig,
33    #[serde(skip)]
34    name: String,
35}
36
37impl SageMakerTgiProvider {
38    pub async fn from_env(model: ModelConfig) -> Result<Self> {
39        let config = crate::config::Config::global();
40
41        // Get SageMaker endpoint name (just the name, not full URL)
42        let endpoint_name: String = config.get_param("SAGEMAKER_ENDPOINT_NAME").map_err(|_| {
43            anyhow::anyhow!("SAGEMAKER_ENDPOINT_NAME is required for SageMaker TGI provider")
44        })?;
45
46        // Attempt to load config and secrets to get AWS_ prefixed keys
47        let set_aws_env_vars = |res: Result<HashMap<String, Value>, _>| {
48            if let Ok(map) = res {
49                map.into_iter()
50                    .filter(|(key, _)| key.starts_with("AWS_"))
51                    .filter_map(|(key, value)| value.as_str().map(|s| (key, s.to_string())))
52                    .for_each(|(key, s)| std::env::set_var(key, s));
53            }
54        };
55
56        set_aws_env_vars(config.all_values());
57        set_aws_env_vars(config.all_secrets());
58
59        let aws_config = aws_config::load_from_env().await;
60
61        // Validate credentials
62        aws_config
63            .credentials_provider()
64            .unwrap()
65            .provide_credentials()
66            .await?;
67
68        // Create client with longer timeout for model initialization
69        let timeout_config = aws_config::timeout::TimeoutConfig::builder()
70            .operation_timeout(Duration::from_secs(300)) // 5 minutes for cold starts
71            .build();
72
73        let config_with_timeout = aws_config
74            .into_builder()
75            .timeout_config(timeout_config)
76            .build();
77
78        let sagemaker_client = SageMakerClient::new(&config_with_timeout);
79
80        Ok(Self {
81            sagemaker_client,
82            endpoint_name,
83            model,
84            name: Self::metadata().name,
85        })
86    }
87
88    fn create_tgi_request(&self, system: &str, messages: &[Message]) -> Result<Value> {
89        // Create a simplified prompt for TGI models using recent user and assistant messages.
90        // Uses a minimal system prompt and avoids HTML or tool-related formatting.
91        let mut prompt = String::new();
92
93        // Use a very simple system prompt if provided, but ensure it doesn't contain HTML instructions
94        if !system.is_empty()
95            && !system.contains("Available tools")
96            && system.len() < 200
97            && !system.contains("HTML")
98            && !system.contains("markdown")
99        {
100            prompt.push_str(&format!("System: {}\n\n", system));
101        } else {
102            // Use a minimal system prompt for TGI that explicitly avoids HTML
103            prompt.push_str("System: You are a helpful AI assistant. Provide responses in plain text only. Do not use HTML tags, markup, or formatting.\n\n");
104        }
105
106        // Only include the most recent user messages to avoid overwhelming the model
107        let recent_messages: Vec<_> = messages.iter().rev().take(3).collect();
108        for message in recent_messages.iter().rev() {
109            match &message.role {
110                Role::User => {
111                    prompt.push_str("User: ");
112                    for content in &message.content {
113                        if let MessageContent::Text(text) = content {
114                            prompt.push_str(&text.text);
115                        }
116                    }
117                    prompt.push_str("\n\n");
118                }
119                Role::Assistant => {
120                    prompt.push_str("Assistant: ");
121                    for content in &message.content {
122                        if let MessageContent::Text(text) = content {
123                            // Skip responses that look like tool descriptions or contain HTML
124                            if !text.text.contains("__")
125                                && !text.text.contains("Available tools")
126                                && !text.text.contains("<")
127                            {
128                                prompt.push_str(&text.text);
129                            }
130                        }
131                    }
132                    prompt.push_str("\n\n");
133                }
134            }
135        }
136
137        prompt.push_str("Assistant: ");
138
139        // Skip tool descriptions entirely for TGI models to avoid confusion
140        // TGI models don't support tools natively and including tool descriptions
141        // causes them to mimic that format in their responses
142
143        // Build TGI request with reasonable parameters
144        let request = json!({
145            "inputs": prompt,
146            "parameters": {
147                "max_new_tokens": self.model.max_tokens.unwrap_or(150),
148                "temperature": self.model.temperature.unwrap_or(0.7),
149                "do_sample": true,
150                "return_full_text": false
151            }
152        });
153
154        Ok(request)
155    }
156
157    async fn invoke_endpoint(&self, payload: Value) -> Result<Value, ProviderError> {
158        let body = serde_json::to_string(&payload).map_err(|e| {
159            ProviderError::RequestFailed(format!("Failed to serialize request: {}", e))
160        })?;
161
162        let response = self
163            .sagemaker_client
164            .invoke_endpoint()
165            .endpoint_name(&self.endpoint_name)
166            .content_type("application/json")
167            .body(body.into_bytes().into())
168            .send()
169            .await
170            .map_err(|e| ProviderError::RequestFailed(format!("SageMaker invoke failed: {}", e)))?;
171
172        let response_body = response
173            .body
174            .as_ref()
175            .ok_or_else(|| ProviderError::RequestFailed("Empty response body".to_string()))?;
176        let response_text = std::str::from_utf8(response_body.as_ref()).map_err(|e| {
177            ProviderError::RequestFailed(format!("Failed to decode response: {}", e))
178        })?;
179
180        serde_json::from_str(response_text).map_err(|e| {
181            ProviderError::RequestFailed(format!("Failed to parse response JSON: {}", e))
182        })
183    }
184
185    fn parse_tgi_response(&self, response: Value) -> Result<Message, ProviderError> {
186        // Handle standard TGI response: [{"generated_text": "..."}]
187        let response_array = response
188            .as_array()
189            .ok_or_else(|| ProviderError::RequestFailed("Expected array response".to_string()))?;
190
191        if response_array.is_empty() {
192            return Err(ProviderError::RequestFailed(
193                "Empty response array".to_string(),
194            ));
195        }
196
197        let first_result = &response_array[0];
198        let generated_text = first_result
199            .get("generated_text")
200            .and_then(|v| v.as_str())
201            .ok_or_else(|| {
202                ProviderError::RequestFailed("No generated_text in response".to_string())
203            })?;
204
205        // Strip any HTML tags that might have been generated
206        let clean_text = self.strip_html_tags(generated_text);
207
208        Ok(Message::new(
209            Role::Assistant,
210            Utc::now().timestamp(),
211            vec![MessageContent::text(clean_text)],
212        ))
213    }
214
215    /// Strip HTML tags from text to ensure clean output
216    fn strip_html_tags(&self, text: &str) -> String {
217        // Simple regex-free approach to strip common HTML tags
218        let mut result = text.to_string();
219
220        // Remove common HTML tags like <b>, <i>, <strong>, <em>, etc.
221        let tags_to_remove = [
222            "<b>",
223            "</b>",
224            "<i>",
225            "</i>",
226            "<strong>",
227            "</strong>",
228            "<em>",
229            "</em>",
230            "<u>",
231            "</u>",
232            "<br>",
233            "<br/>",
234            "<p>",
235            "</p>",
236            "<div>",
237            "</div>",
238            "<span>",
239            "</span>",
240        ];
241
242        for tag in &tags_to_remove {
243            result = result.replace(tag, "");
244        }
245
246        // Remove any remaining HTML-like tags using a simple pattern
247        // This is a basic implementation - for production use, consider using a proper HTML parser
248        while let Some(start) = result.find('<') {
249            if let Some(end) = result.get(start..).and_then(|s| s.find('>')) {
250                result.replace_range(start..start + end + 1, "");
251            } else {
252                break;
253            }
254        }
255
256        result.trim().to_string()
257    }
258}
259
260#[async_trait]
261impl Provider for SageMakerTgiProvider {
262    fn metadata() -> ProviderMetadata {
263        ProviderMetadata::new(
264            "sagemaker_tgi",
265            "Amazon SageMaker TGI",
266            "Run Text Generation Inference models through Amazon SageMaker endpoints. Requires AWS credentials and a SageMaker endpoint URL.",
267            SAGEMAKER_TGI_DEFAULT_MODEL,
268            vec![SAGEMAKER_TGI_DEFAULT_MODEL],
269            SAGEMAKER_TGI_DOC_LINK,
270            vec![
271                ConfigKey::new("SAGEMAKER_ENDPOINT_NAME", false, false, None),
272                ConfigKey::new("AWS_REGION", true, false, Some("us-east-1")),
273                ConfigKey::new("AWS_PROFILE", true, false, Some("default")),
274            ],
275        )
276    }
277
278    fn get_name(&self) -> &str {
279        &self.name
280    }
281
282    fn get_model_config(&self) -> ModelConfig {
283        self.model.clone()
284    }
285
286    #[tracing::instrument(
287        skip(self, model_config, system, messages, tools),
288        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
289    )]
290    async fn complete_with_model(
291        &self,
292        model_config: &ModelConfig,
293        system: &str,
294        messages: &[Message],
295        tools: &[Tool],
296    ) -> Result<(Message, ProviderUsage), ProviderError> {
297        let model_name = &model_config.model_name;
298
299        let request_payload = self.create_tgi_request(system, messages).map_err(|e| {
300            ProviderError::RequestFailed(format!("Failed to create request: {}", e))
301        })?;
302
303        let response = self
304            .with_retry(|| self.invoke_endpoint(request_payload.clone()))
305            .await?;
306
307        let message = self.parse_tgi_response(response)?;
308
309        // TGI doesn't provide usage statistics, so we estimate
310        let usage = Usage::new(
311            Some(0), // Would need to tokenize input to get accurate count
312            Some(0), // Would need to tokenize output to get accurate count
313            Some(0),
314        );
315
316        // Add debug trace
317        let debug_payload = serde_json::json!({
318            "system": system,
319            "messages": messages,
320            "tools": tools
321        });
322        let mut log = RequestLog::start(&self.model, &debug_payload)?;
323        log.write(
324            &serde_json::to_value(&message).unwrap_or_default(),
325            Some(&usage),
326        )?;
327
328        let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
329        Ok((message, provider_usage))
330    }
331}