1use std::collections::HashMap;
2use std::time::Duration;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use aws_config;
7use aws_sdk_bedrockruntime::config::ProvideCredentials;
8use aws_sdk_sagemakerruntime::Client as SageMakerClient;
9use rmcp::model::Tool;
10use serde_json::{json, Value};
11
12use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
13use super::errors::ProviderError;
14use super::retry::ProviderRetry;
15use super::utils::RequestLog;
16use crate::conversation::message::{Message, MessageContent};
17
18use crate::model::ModelConfig;
19use chrono::Utc;
20use rmcp::model::Role;
21
/// Model name advertised for this provider. It is never sent in the request payload; the
/// endpoint is addressed by its configured name instead.
pub const SAGEMAKER_TGI_DEFAULT_MODEL: &str = "sagemaker-tgi-endpoint";

/// Documentation page for SageMaker real-time endpoints, surfaced in the provider metadata.
pub const SAGEMAKER_TGI_DOC_LINK: &str = concat!(
    "https://docs.aws.amazon.com/sagemaker/latest/dg/",
    "realtime-endpoints.html"
);
26
27#[derive(Debug, serde::Serialize)]
28pub struct SageMakerTgiProvider {
29 #[serde(skip)]
30 sagemaker_client: SageMakerClient,
31 endpoint_name: String,
32 model: ModelConfig,
33 #[serde(skip)]
34 name: String,
35}
36
37impl SageMakerTgiProvider {
38 pub async fn from_env(model: ModelConfig) -> Result<Self> {
39 let config = crate::config::Config::global();
40
41 let endpoint_name: String = config.get_param("SAGEMAKER_ENDPOINT_NAME").map_err(|_| {
43 anyhow::anyhow!("SAGEMAKER_ENDPOINT_NAME is required for SageMaker TGI provider")
44 })?;
45
46 let set_aws_env_vars = |res: Result<HashMap<String, Value>, _>| {
48 if let Ok(map) = res {
49 map.into_iter()
50 .filter(|(key, _)| key.starts_with("AWS_"))
51 .filter_map(|(key, value)| value.as_str().map(|s| (key, s.to_string())))
52 .for_each(|(key, s)| std::env::set_var(key, s));
53 }
54 };
55
56 set_aws_env_vars(config.all_values());
57 set_aws_env_vars(config.all_secrets());
58
59 let aws_config = aws_config::load_from_env().await;
60
61 aws_config
63 .credentials_provider()
64 .unwrap()
65 .provide_credentials()
66 .await?;
67
68 let timeout_config = aws_config::timeout::TimeoutConfig::builder()
70 .operation_timeout(Duration::from_secs(300)) .build();
72
73 let config_with_timeout = aws_config
74 .into_builder()
75 .timeout_config(timeout_config)
76 .build();
77
78 let sagemaker_client = SageMakerClient::new(&config_with_timeout);
79
80 Ok(Self {
81 sagemaker_client,
82 endpoint_name,
83 model,
84 name: Self::metadata().name,
85 })
86 }
87
88 fn create_tgi_request(&self, system: &str, messages: &[Message]) -> Result<Value> {
89 let mut prompt = String::new();
92
93 if !system.is_empty()
95 && !system.contains("Available tools")
96 && system.len() < 200
97 && !system.contains("HTML")
98 && !system.contains("markdown")
99 {
100 prompt.push_str(&format!("System: {}\n\n", system));
101 } else {
102 prompt.push_str("System: You are a helpful AI assistant. Provide responses in plain text only. Do not use HTML tags, markup, or formatting.\n\n");
104 }
105
106 let recent_messages: Vec<_> = messages.iter().rev().take(3).collect();
108 for message in recent_messages.iter().rev() {
109 match &message.role {
110 Role::User => {
111 prompt.push_str("User: ");
112 for content in &message.content {
113 if let MessageContent::Text(text) = content {
114 prompt.push_str(&text.text);
115 }
116 }
117 prompt.push_str("\n\n");
118 }
119 Role::Assistant => {
120 prompt.push_str("Assistant: ");
121 for content in &message.content {
122 if let MessageContent::Text(text) = content {
123 if !text.text.contains("__")
125 && !text.text.contains("Available tools")
126 && !text.text.contains("<")
127 {
128 prompt.push_str(&text.text);
129 }
130 }
131 }
132 prompt.push_str("\n\n");
133 }
134 }
135 }
136
137 prompt.push_str("Assistant: ");
138
139 let request = json!({
145 "inputs": prompt,
146 "parameters": {
147 "max_new_tokens": self.model.max_tokens.unwrap_or(150),
148 "temperature": self.model.temperature.unwrap_or(0.7),
149 "do_sample": true,
150 "return_full_text": false
151 }
152 });
153
154 Ok(request)
155 }
156
157 async fn invoke_endpoint(&self, payload: Value) -> Result<Value, ProviderError> {
158 let body = serde_json::to_string(&payload).map_err(|e| {
159 ProviderError::RequestFailed(format!("Failed to serialize request: {}", e))
160 })?;
161
162 let response = self
163 .sagemaker_client
164 .invoke_endpoint()
165 .endpoint_name(&self.endpoint_name)
166 .content_type("application/json")
167 .body(body.into_bytes().into())
168 .send()
169 .await
170 .map_err(|e| ProviderError::RequestFailed(format!("SageMaker invoke failed: {}", e)))?;
171
172 let response_body = response
173 .body
174 .as_ref()
175 .ok_or_else(|| ProviderError::RequestFailed("Empty response body".to_string()))?;
176 let response_text = std::str::from_utf8(response_body.as_ref()).map_err(|e| {
177 ProviderError::RequestFailed(format!("Failed to decode response: {}", e))
178 })?;
179
180 serde_json::from_str(response_text).map_err(|e| {
181 ProviderError::RequestFailed(format!("Failed to parse response JSON: {}", e))
182 })
183 }
184
185 fn parse_tgi_response(&self, response: Value) -> Result<Message, ProviderError> {
186 let response_array = response
188 .as_array()
189 .ok_or_else(|| ProviderError::RequestFailed("Expected array response".to_string()))?;
190
191 if response_array.is_empty() {
192 return Err(ProviderError::RequestFailed(
193 "Empty response array".to_string(),
194 ));
195 }
196
197 let first_result = &response_array[0];
198 let generated_text = first_result
199 .get("generated_text")
200 .and_then(|v| v.as_str())
201 .ok_or_else(|| {
202 ProviderError::RequestFailed("No generated_text in response".to_string())
203 })?;
204
205 let clean_text = self.strip_html_tags(generated_text);
207
208 Ok(Message::new(
209 Role::Assistant,
210 Utc::now().timestamp(),
211 vec![MessageContent::text(clean_text)],
212 ))
213 }
214
215 fn strip_html_tags(&self, text: &str) -> String {
217 let mut result = text.to_string();
219
220 let tags_to_remove = [
222 "<b>",
223 "</b>",
224 "<i>",
225 "</i>",
226 "<strong>",
227 "</strong>",
228 "<em>",
229 "</em>",
230 "<u>",
231 "</u>",
232 "<br>",
233 "<br/>",
234 "<p>",
235 "</p>",
236 "<div>",
237 "</div>",
238 "<span>",
239 "</span>",
240 ];
241
242 for tag in &tags_to_remove {
243 result = result.replace(tag, "");
244 }
245
246 while let Some(start) = result.find('<') {
249 if let Some(end) = result.get(start..).and_then(|s| s.find('>')) {
250 result.replace_range(start..start + end + 1, "");
251 } else {
252 break;
253 }
254 }
255
256 result.trim().to_string()
257 }
258}
259
260#[async_trait]
261impl Provider for SageMakerTgiProvider {
262 fn metadata() -> ProviderMetadata {
263 ProviderMetadata::new(
264 "sagemaker_tgi",
265 "Amazon SageMaker TGI",
266 "Run Text Generation Inference models through Amazon SageMaker endpoints. Requires AWS credentials and a SageMaker endpoint URL.",
267 SAGEMAKER_TGI_DEFAULT_MODEL,
268 vec![SAGEMAKER_TGI_DEFAULT_MODEL],
269 SAGEMAKER_TGI_DOC_LINK,
270 vec![
271 ConfigKey::new("SAGEMAKER_ENDPOINT_NAME", false, false, None),
272 ConfigKey::new("AWS_REGION", true, false, Some("us-east-1")),
273 ConfigKey::new("AWS_PROFILE", true, false, Some("default")),
274 ],
275 )
276 }
277
278 fn get_name(&self) -> &str {
279 &self.name
280 }
281
282 fn get_model_config(&self) -> ModelConfig {
283 self.model.clone()
284 }
285
286 #[tracing::instrument(
287 skip(self, model_config, system, messages, tools),
288 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
289 )]
290 async fn complete_with_model(
291 &self,
292 model_config: &ModelConfig,
293 system: &str,
294 messages: &[Message],
295 tools: &[Tool],
296 ) -> Result<(Message, ProviderUsage), ProviderError> {
297 let model_name = &model_config.model_name;
298
299 let request_payload = self.create_tgi_request(system, messages).map_err(|e| {
300 ProviderError::RequestFailed(format!("Failed to create request: {}", e))
301 })?;
302
303 let response = self
304 .with_retry(|| self.invoke_endpoint(request_payload.clone()))
305 .await?;
306
307 let message = self.parse_tgi_response(response)?;
308
309 let usage = Usage::new(
311 Some(0), Some(0), Some(0),
314 );
315
316 let debug_payload = serde_json::json!({
318 "system": system,
319 "messages": messages,
320 "tools": tools
321 });
322 let mut log = RequestLog::start(&self.model, &debug_payload)?;
323 log.write(
324 &serde_json::to_value(&message).unwrap_or_default(),
325 Some(&usage),
326 )?;
327
328 let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
329 Ok((message, provider_usage))
330 }
331}