1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use aws_sdk_bedrockruntime::types::{
5 self as bedrock_types, ContentBlock, ConversationRole, InferenceConfiguration,
6 SystemContentBlock, ToolConfiguration, ToolInputSchema, ToolResultBlock,
7 ToolResultContentBlock, ToolSpecification, ToolUseBlock,
8};
9use aws_smithy_types::Document as SmithyDocument;
10use serde_json::Value;
11use synaptic_core::{
12 AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
13 TokenUsage, ToolCall, ToolCallChunk, ToolChoice,
14};
15
/// Settings for a Bedrock chat model: which model to invoke plus optional
/// inference parameters.
///
/// Every optional field starts out as `None`; use the `with_*` builder
/// methods to set the ones you need.
#[derive(Debug, Clone)]
pub struct BedrockConfig {
    /// Identifier of the Bedrock model to invoke (sent as the request's model id).
    pub model_id: String,
    /// Optional AWS region override applied when the client is built.
    pub region: Option<String>,
    /// Optional cap on the number of generated tokens.
    pub max_tokens: Option<i32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling (top-p) value.
    pub top_p: Option<f32>,
    /// Optional stop sequences.
    pub stop: Option<Vec<String>>,
}

impl BedrockConfig {
    /// Creates a configuration for `model_id` with every optional setting unset.
    pub fn new(model_id: impl Into<String>) -> Self {
        let model_id = model_id.into();
        Self {
            model_id,
            region: None,
            max_tokens: None,
            temperature: None,
            top_p: None,
            stop: None,
        }
    }

    /// Returns the configuration with `region` set.
    pub fn with_region(self, region: impl Into<String>) -> Self {
        Self {
            region: Some(region.into()),
            ..self
        }
    }

    /// Returns the configuration with `max_tokens` set.
    pub fn with_max_tokens(self, max_tokens: i32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Returns the configuration with `temperature` set.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Returns the configuration with `top_p` set.
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Returns the configuration with the given stop sequences.
    pub fn with_stop(self, stop: Vec<String>) -> Self {
        Self {
            stop: Some(stop),
            ..self
        }
    }
}
80
81pub struct BedrockChatModel {
90 config: BedrockConfig,
91 client: aws_sdk_bedrockruntime::Client,
92}
93
94impl BedrockChatModel {
95 pub async fn new(config: BedrockConfig) -> Self {
99 let mut aws_config_loader = aws_config::from_env();
100
101 if let Some(ref region) = config.region {
102 aws_config_loader =
103 aws_config_loader.region(aws_config::Region::new(region.clone()));
104 }
105
106 let aws_config = aws_config_loader.load().await;
107 let client = aws_sdk_bedrockruntime::Client::new(&aws_config);
108
109 Self { config, client }
110 }
111
112 pub fn from_client(config: BedrockConfig, client: aws_sdk_bedrockruntime::Client) -> Self {
114 Self { config, client }
115 }
116
117 fn build_inference_config(&self) -> Option<InferenceConfiguration> {
119 let has_any = self.config.max_tokens.is_some()
120 || self.config.temperature.is_some()
121 || self.config.top_p.is_some()
122 || self.config.stop.is_some();
123
124 if !has_any {
125 return None;
126 }
127
128 let mut builder = InferenceConfiguration::builder();
129
130 if let Some(max_tokens) = self.config.max_tokens {
131 builder = builder.max_tokens(max_tokens);
132 }
133 if let Some(temperature) = self.config.temperature {
134 builder = builder.temperature(temperature);
135 }
136 if let Some(top_p) = self.config.top_p {
137 builder = builder.top_p(top_p);
138 }
139 if let Some(ref stop) = self.config.stop {
140 for s in stop {
141 builder = builder.stop_sequences(s.clone());
142 }
143 }
144
145 Some(builder.build())
146 }
147
148 fn build_tool_config(
150 &self,
151 request: &ChatRequest,
152 ) -> Option<ToolConfiguration> {
153 if request.tools.is_empty() {
154 return None;
155 }
156
157 let tools: Vec<bedrock_types::Tool> = request
158 .tools
159 .iter()
160 .map(|td| {
161 let spec = ToolSpecification::builder()
162 .name(&td.name)
163 .description(&td.description)
164 .input_schema(ToolInputSchema::Json(json_value_to_document(&td.parameters)))
165 .build()
166 .expect("tool specification build should not fail");
167
168 bedrock_types::Tool::ToolSpec(spec)
169 })
170 .collect();
171
172 let mut builder = ToolConfiguration::builder();
173 for tool in tools {
174 builder = builder.tools(tool);
175 }
176
177 if let Some(ref choice) = request.tool_choice {
178 let bedrock_choice = match choice {
179 ToolChoice::Auto => {
180 bedrock_types::ToolChoice::Auto(
181 bedrock_types::AutoToolChoice::builder().build(),
182 )
183 }
184 ToolChoice::Required => {
185 bedrock_types::ToolChoice::Any(
186 bedrock_types::AnyToolChoice::builder().build(),
187 )
188 }
189 ToolChoice::None => {
190 bedrock_types::ToolChoice::Auto(
193 bedrock_types::AutoToolChoice::builder().build(),
194 )
195 }
196 ToolChoice::Specific(name) => {
197 bedrock_types::ToolChoice::Tool(
198 bedrock_types::SpecificToolChoice::builder()
199 .name(name)
200 .build()
201 .expect("specific tool choice build should not fail"),
202 )
203 }
204 };
205 builder = builder.tool_choice(bedrock_choice);
206 }
207
208 Some(builder.build().expect("tool configuration build should not fail"))
209 }
210}
211
212#[async_trait]
213impl ChatModel for BedrockChatModel {
214 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
215 let (system_blocks, messages) = convert_messages(&request.messages);
216
217 let mut converse = self
218 .client
219 .converse()
220 .model_id(&self.config.model_id);
221
222 for block in system_blocks {
224 converse = converse.system(block);
225 }
226
227 for msg in messages {
229 converse = converse.messages(msg);
230 }
231
232 if let Some(inference_config) = self.build_inference_config() {
234 converse = converse.inference_config(inference_config);
235 }
236
237 if let Some(tool_config) = self.build_tool_config(&request) {
239 converse = converse.tool_config(tool_config);
240 }
241
242 let output = converse
243 .send()
244 .await
245 .map_err(|e| SynapticError::Model(format!("Bedrock Converse API error: {e}")))?;
246
247 let usage = output.usage().map(|u| TokenUsage {
249 input_tokens: u.input_tokens() as u32,
250 output_tokens: u.output_tokens() as u32,
251 total_tokens: u.total_tokens() as u32,
252 input_details: None,
253 output_details: None,
254 });
255
256 let message = match output.output() {
258 Some(bedrock_types::ConverseOutput::Message(msg)) => {
259 parse_bedrock_message(msg)
260 }
261 _ => Message::ai(""),
262 };
263
264 Ok(ChatResponse { message, usage })
265 }
266
267 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
268 Box::pin(async_stream::stream! {
269 let (system_blocks, messages) = convert_messages(&request.messages);
270
271 let mut converse_stream = self
272 .client
273 .converse_stream()
274 .model_id(&self.config.model_id);
275
276 for block in system_blocks {
277 converse_stream = converse_stream.system(block);
278 }
279
280 for msg in messages {
281 converse_stream = converse_stream.messages(msg);
282 }
283
284 if let Some(inference_config) = self.build_inference_config() {
285 converse_stream = converse_stream.inference_config(inference_config);
286 }
287
288 if let Some(tool_config) = self.build_tool_config(&request) {
289 converse_stream = converse_stream.tool_config(tool_config);
290 }
291
292 let output = match converse_stream.send().await {
293 Ok(o) => o,
294 Err(e) => {
295 yield Err(SynapticError::Model(format!(
296 "Bedrock ConverseStream API error: {e}"
297 )));
298 return;
299 }
300 };
301
302 let mut stream = output.stream;
303
304 let mut current_tool_id: Option<String> = None;
306 let mut current_tool_name: Option<String> = None;
307 let mut current_tool_input: String = String::new();
308
309 loop {
310 match stream.recv().await {
311 Ok(Some(event)) => {
312 match event {
313 bedrock_types::ConverseStreamOutput::ContentBlockStart(start_event) => {
314 if let Some(bedrock_types::ContentBlockStart::ToolUse(tool_start)) = start_event.start() {
315 current_tool_id = Some(tool_start.tool_use_id().to_string());
316 current_tool_name = Some(tool_start.name().to_string());
317 current_tool_input.clear();
318
319 yield Ok(AIMessageChunk {
320 tool_call_chunks: vec![ToolCallChunk {
321 id: Some(tool_start.tool_use_id().to_string()),
322 name: Some(tool_start.name().to_string()),
323 arguments: None,
324 index: Some(start_event.content_block_index() as usize),
325 }],
326 ..Default::default()
327 });
328 }
329 }
330 bedrock_types::ConverseStreamOutput::ContentBlockDelta(delta_event) => {
331 if let Some(delta) = delta_event.delta() {
332 match delta {
333 bedrock_types::ContentBlockDelta::Text(text) => {
334 yield Ok(AIMessageChunk {
335 content: text.to_string(),
336 ..Default::default()
337 });
338 }
339 bedrock_types::ContentBlockDelta::ToolUse(tool_delta) => {
340 let input_fragment = tool_delta.input();
341 current_tool_input.push_str(input_fragment);
342
343 yield Ok(AIMessageChunk {
344 tool_call_chunks: vec![ToolCallChunk {
345 id: current_tool_id.clone(),
346 name: current_tool_name.clone(),
347 arguments: Some(input_fragment.to_string()),
348 index: Some(delta_event.content_block_index() as usize),
349 }],
350 ..Default::default()
351 });
352 }
353 _ => { }
354 }
355 }
356 }
357 bedrock_types::ConverseStreamOutput::ContentBlockStop(_) => {
358 if let (Some(id), Some(name)) = (current_tool_id.take(), current_tool_name.take()) {
360 let arguments: Value = serde_json::from_str(¤t_tool_input)
361 .unwrap_or(Value::Object(Default::default()));
362 current_tool_input.clear();
363
364 yield Ok(AIMessageChunk {
365 tool_calls: vec![ToolCall {
366 id,
367 name,
368 arguments,
369 }],
370 ..Default::default()
371 });
372 }
373 }
374 bedrock_types::ConverseStreamOutput::Metadata(meta) => {
375 if let Some(u) = meta.usage() {
376 yield Ok(AIMessageChunk {
377 usage: Some(TokenUsage {
378 input_tokens: u.input_tokens() as u32,
379 output_tokens: u.output_tokens() as u32,
380 total_tokens: u.total_tokens() as u32,
381 input_details: None,
382 output_details: None,
383 }),
384 ..Default::default()
385 });
386 }
387 }
388 _ => { }
389 }
390 }
391 Ok(None) => break,
392 Err(e) => {
393 yield Err(SynapticError::Model(format!(
394 "Bedrock stream error: {e}"
395 )));
396 break;
397 }
398 }
399 }
400 })
401 }
402}
403
404fn convert_messages(
413 messages: &[Message],
414) -> (Vec<SystemContentBlock>, Vec<bedrock_types::Message>) {
415 let mut system_blocks = Vec::new();
416 let mut bedrock_messages: Vec<bedrock_types::Message> = Vec::new();
417
418 for msg in messages {
419 match msg {
420 Message::System { content, .. } => {
421 system_blocks.push(SystemContentBlock::Text(content.clone()));
422 }
423 Message::Human { content, .. } => {
424 let bedrock_msg = bedrock_types::Message::builder()
425 .role(ConversationRole::User)
426 .content(ContentBlock::Text(content.clone()))
427 .build()
428 .expect("message build should not fail");
429 bedrock_messages.push(bedrock_msg);
430 }
431 Message::AI {
432 content,
433 tool_calls,
434 ..
435 } => {
436 let mut blocks: Vec<ContentBlock> = Vec::new();
437
438 if !content.is_empty() {
439 blocks.push(ContentBlock::Text(content.clone()));
440 }
441
442 for tc in tool_calls {
443 let tool_use = ToolUseBlock::builder()
444 .tool_use_id(&tc.id)
445 .name(&tc.name)
446 .input(json_value_to_document(&tc.arguments))
447 .build()
448 .expect("tool use block build should not fail");
449 blocks.push(ContentBlock::ToolUse(tool_use));
450 }
451
452 if blocks.is_empty() {
454 blocks.push(ContentBlock::Text(String::new()));
455 }
456
457 let bedrock_msg = bedrock_types::Message::builder()
458 .role(ConversationRole::Assistant)
459 .set_content(Some(blocks))
460 .build()
461 .expect("message build should not fail");
462 bedrock_messages.push(bedrock_msg);
463 }
464 Message::Tool {
465 content,
466 tool_call_id,
467 ..
468 } => {
469 let tool_result = ToolResultBlock::builder()
470 .tool_use_id(tool_call_id)
471 .content(ToolResultContentBlock::Text(content.clone()))
472 .build()
473 .expect("tool result block build should not fail");
474
475 let bedrock_msg = bedrock_types::Message::builder()
476 .role(ConversationRole::User)
477 .content(ContentBlock::ToolResult(tool_result))
478 .build()
479 .expect("message build should not fail");
480 bedrock_messages.push(bedrock_msg);
481 }
482 Message::Chat { content, .. } => {
483 let bedrock_msg = bedrock_types::Message::builder()
485 .role(ConversationRole::User)
486 .content(ContentBlock::Text(content.clone()))
487 .build()
488 .expect("message build should not fail");
489 bedrock_messages.push(bedrock_msg);
490 }
491 Message::Remove { .. } => { }
492 }
493 }
494
495 (system_blocks, bedrock_messages)
496}
497
498fn parse_bedrock_message(msg: &bedrock_types::Message) -> Message {
500 let mut text_parts: Vec<String> = Vec::new();
501 let mut tool_calls: Vec<ToolCall> = Vec::new();
502
503 for block in msg.content() {
504 match block {
505 ContentBlock::Text(text) => {
506 text_parts.push(text.clone());
507 }
508 ContentBlock::ToolUse(tool_use) => {
509 tool_calls.push(ToolCall {
510 id: tool_use.tool_use_id().to_string(),
511 name: tool_use.name().to_string(),
512 arguments: document_to_json_value(tool_use.input()),
513 });
514 }
515 _ => { }
516 }
517 }
518
519 let content = text_parts.join("");
520
521 if tool_calls.is_empty() {
522 Message::ai(content)
523 } else {
524 Message::ai_with_tool_calls(content, tool_calls)
525 }
526}
527
528pub(crate) fn json_value_to_document(value: &Value) -> SmithyDocument {
534 match value {
535 Value::Null => SmithyDocument::Null,
536 Value::Bool(b) => SmithyDocument::Bool(*b),
537 Value::Number(n) => {
538 if let Some(i) = n.as_i64() {
539 SmithyDocument::Number(aws_smithy_types::Number::NegInt(i))
540 } else if let Some(u) = n.as_u64() {
541 SmithyDocument::Number(aws_smithy_types::Number::PosInt(u))
542 } else if let Some(f) = n.as_f64() {
543 SmithyDocument::Number(aws_smithy_types::Number::Float(f))
544 } else {
545 SmithyDocument::Null
546 }
547 }
548 Value::String(s) => SmithyDocument::String(s.clone()),
549 Value::Array(arr) => {
550 SmithyDocument::Array(arr.iter().map(json_value_to_document).collect())
551 }
552 Value::Object(obj) => {
553 let map: HashMap<String, SmithyDocument> = obj
554 .iter()
555 .map(|(k, v)| (k.clone(), json_value_to_document(v)))
556 .collect();
557 SmithyDocument::Object(map)
558 }
559 }
560}
561
562pub(crate) fn document_to_json_value(doc: &SmithyDocument) -> Value {
564 match doc {
565 SmithyDocument::Null => Value::Null,
566 SmithyDocument::Bool(b) => Value::Bool(*b),
567 SmithyDocument::Number(n) => match *n {
568 aws_smithy_types::Number::PosInt(u) => {
569 serde_json::json!(u)
570 }
571 aws_smithy_types::Number::NegInt(i) => {
572 serde_json::json!(i)
573 }
574 aws_smithy_types::Number::Float(f) => {
575 serde_json::json!(f)
576 }
577 },
578 SmithyDocument::String(s) => Value::String(s.clone()),
579 SmithyDocument::Array(arr) => {
580 Value::Array(arr.iter().map(document_to_json_value).collect())
581 }
582 SmithyDocument::Object(obj) => {
583 let map: serde_json::Map<String, Value> = obj
584 .iter()
585 .map(|(k, v)| (k.clone(), document_to_json_value(v)))
586 .collect();
587 Value::Object(map)
588 }
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn json_value_to_document_round_trip() {
        let original = serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        });

        let doc = json_value_to_document(&original);
        let back = document_to_json_value(&doc);
        assert_eq!(original, back);
    }

    /// The schema round trip above contains no numbers; cover every numeric
    /// path (zero, positive, negative, u64/i64 extremes, floats).
    #[test]
    fn json_value_to_document_round_trip_numbers() {
        let original = serde_json::json!({
            "zero": 0,
            "positive": 42,
            "negative": -7,
            "max_u64": u64::MAX,
            "min_i64": i64::MIN,
            "float": 1.5,
            "mixed": [0, -1, 2.25, null, true, "x"]
        });

        let doc = json_value_to_document(&original);
        let back = document_to_json_value(&doc);
        assert_eq!(original, back);
    }

    #[test]
    fn json_value_to_document_primitives() {
        assert!(matches!(
            json_value_to_document(&Value::Null),
            SmithyDocument::Null
        ));
        assert!(matches!(
            json_value_to_document(&Value::Bool(true)),
            SmithyDocument::Bool(true)
        ));
        assert!(matches!(
            json_value_to_document(&serde_json::json!("hello")),
            SmithyDocument::String(_)
        ));
    }

    #[test]
    fn convert_system_messages() {
        let messages = vec![
            Message::system("You are a helpful assistant."),
            Message::human("Hello!"),
        ];

        let (system_blocks, bedrock_messages) = convert_messages(&messages);
        assert_eq!(system_blocks.len(), 1);
        assert_eq!(bedrock_messages.len(), 1);
    }

    #[test]
    fn convert_tool_messages() {
        let messages = vec![
            Message::human("What is the weather?"),
            Message::ai_with_tool_calls(
                "",
                vec![ToolCall {
                    id: "tc_1".to_string(),
                    name: "get_weather".to_string(),
                    arguments: serde_json::json!({"city": "NYC"}),
                }],
            ),
            Message::tool("Sunny, 72F", "tc_1"),
        ];

        let (system_blocks, bedrock_messages) = convert_messages(&messages);
        assert!(system_blocks.is_empty());
        assert_eq!(bedrock_messages.len(), 3);

        assert_eq!(*bedrock_messages[0].role(), ConversationRole::User);
        assert_eq!(*bedrock_messages[1].role(), ConversationRole::Assistant);
        assert_eq!(*bedrock_messages[2].role(), ConversationRole::User);
    }

    #[test]
    fn convert_remove_messages_are_skipped() {
        let messages = vec![
            Message::human("Hi"),
            Message::remove("some-id"),
            Message::ai("Hello!"),
        ];

        let (_, bedrock_messages) = convert_messages(&messages);
        assert_eq!(bedrock_messages.len(), 2);
    }

    /// An AI message with no text and no tool calls must still produce an
    /// assistant message with exactly one content block.
    #[test]
    fn convert_empty_ai_message_keeps_a_content_block() {
        let messages = vec![Message::human("Hi"), Message::ai("")];

        let (_, bedrock_messages) = convert_messages(&messages);
        assert_eq!(bedrock_messages.len(), 2);
        assert_eq!(*bedrock_messages[1].role(), ConversationRole::Assistant);
        assert_eq!(bedrock_messages[1].content().len(), 1);
    }

    #[test]
    fn parse_text_only_message() {
        let msg = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Hello world".to_string()))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&msg);
        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Hello world");
        assert!(parsed.tool_calls().is_empty());
    }

    /// Multiple text blocks are concatenated without a separator.
    #[test]
    fn parse_message_concatenates_text_blocks() {
        let msg = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Hello ".to_string()))
            .content(ContentBlock::Text("world".to_string()))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&msg);
        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Hello world");
        assert!(parsed.tool_calls().is_empty());
    }

    #[test]
    fn parse_message_with_tool_use() {
        let tool_use = ToolUseBlock::builder()
            .tool_use_id("tc_1")
            .name("calculator")
            .input(json_value_to_document(&serde_json::json!({"expr": "1+1"})))
            .build()
            .unwrap();

        let msg = bedrock_types::Message::builder()
            .role(ConversationRole::Assistant)
            .content(ContentBlock::Text("Let me calculate.".to_string()))
            .content(ContentBlock::ToolUse(tool_use))
            .build()
            .unwrap();

        let parsed = parse_bedrock_message(&msg);
        assert!(parsed.is_ai());
        assert_eq!(parsed.content(), "Let me calculate.");
        assert_eq!(parsed.tool_calls().len(), 1);
        assert_eq!(parsed.tool_calls()[0].id, "tc_1");
        assert_eq!(parsed.tool_calls()[0].name, "calculator");
        assert_eq!(parsed.tool_calls()[0].arguments, serde_json::json!({"expr": "1+1"}));
    }
}