use crate::types::json::AwsDocument;
use crate::types::message::RigMessage;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use aws_sdk_bedrockruntime::types::{
CachePointBlock, CachePointType, InferenceConfiguration, SystemContentBlock, Tool,
ToolConfiguration, ToolInputSchema, ToolSpecification,
};
use rig_core::OneOrMany;
use rig_core::completion::{CompletionError, Message};
use rig_core::message::{DocumentMediaType, UserContent};
/// A rig [`CompletionRequest`](rig_core::completion::CompletionRequest) paired with
/// Bedrock-specific options.
///
/// The methods on this type turn the request into Bedrock request components: the
/// system prompt, messages, tool configuration, inference configuration, output
/// configuration and additional model parameters.
pub struct AwsCompletionRequest {
    /// The provider-agnostic request being converted.
    pub inner: rig_core::completion::CompletionRequest,
    /// When `true`, cache-point blocks are appended after the system prompt and, unless the
    /// chat history contains reasoning content, after the last message.
    pub prompt_caching: bool,
}
/// Builds a [`CachePointBlock`] of type [`CachePointType::Default`], used to mark a
/// prompt-caching boundary.
///
/// # Errors
/// Returns [`CompletionError::RequestError`] if the SDK builder rejects the block.
fn cache_point_block() -> Result<CachePointBlock, CompletionError> {
    let builder = CachePointBlock::builder().r#type(CachePointType::Default);
    match builder.build() {
        Ok(block) => Ok(block),
        Err(err) => Err(CompletionError::RequestError(err.into())),
    }
}
impl AwsCompletionRequest {
    /// Converts the request's `additional_params` JSON, if any, into a Smithy
    /// [`Document`](aws_smithy_types::Document) via [`AwsDocument`].
    pub fn additional_params(&self) -> Option<aws_smithy_types::Document> {
        self.inner
            .additional_params
            .clone()
            .map(|params| params.into())
            .map(|doc: AwsDocument| doc.0)
    }

    /// Builds the inference configuration from the request's `temperature` and `max_tokens`.
    ///
    /// Always returns `Some`. Values that are not set on the request are left unset.
    /// A `max_tokens` value that does not fit in an `i32` is clamped to `i32::MAX`.
    /// Clamping avoids the silent truncation of an `as` cast, which could turn a large
    /// limit into a small or negative one.
    pub fn inference_config(&self) -> Option<InferenceConfiguration> {
        let temperature = self.inner.temperature.map(|temperature| temperature as f32);
        let max_tokens = self
            .inner
            .max_tokens
            .map(|max_tokens| i32::try_from(max_tokens).unwrap_or(i32::MAX));
        Some(
            InferenceConfiguration::builder()
                .set_temperature(temperature)
                .set_max_tokens(max_tokens)
                .build(),
        )
    }

    /// Builds the Bedrock tool configuration from the request's tool definitions and
    /// tool choice.
    ///
    /// Returns `Ok(None)` when the request defines no tools; any `tool_choice` is then
    /// ignored. Empty tool descriptions are omitted rather than sent as empty strings.
    ///
    /// # Errors
    /// Returns [`CompletionError::RequestError`] if an SDK builder rejects a tool
    /// specification, the specific tool choice, or the final configuration.
    pub fn tools_config(&self) -> Result<Option<ToolConfiguration>, CompletionError> {
        if self.inner.tools.is_empty() {
            return Ok(None);
        }
        let tools = self
            .inner
            .tools
            .iter()
            .map(|tool_definition| {
                let doc: AwsDocument = tool_definition.parameters.clone().into();
                // Bedrock requires a non-empty description when one is given, so an empty
                // string is dropped instead of being sent.
                let description = Some(tool_definition.description.clone())
                    .filter(|description| !description.is_empty());
                ToolSpecification::builder()
                    .name(tool_definition.name.clone())
                    .set_description(description)
                    .set_input_schema(Some(ToolInputSchema::Json(doc.0)))
                    .build()
                    .map(Tool::ToolSpec)
                    .map_err(|e| CompletionError::RequestError(e.into()))
            })
            .collect::<Result<Vec<_>, _>>()?;
        let config = ToolConfiguration::builder()
            .set_tools(Some(tools))
            .set_tool_choice(self.bedrock_tool_choice()?)
            .build()
            .map_err(|e| CompletionError::RequestError(e.into()))?;
        Ok(Some(config))
    }

    /// Maps the request's rig tool choice onto a Bedrock [`aws_bedrock::ToolChoice`].
    ///
    /// Returns `Ok(None)` when no tool choice is set, for `ToolChoice::None`, and for an
    /// empty `Specific` list.
    fn bedrock_tool_choice(&self) -> Result<Option<aws_bedrock::ToolChoice>, CompletionError> {
        use rig_core::message::ToolChoice as RigToolChoice;

        let Some(choice) = self.inner.tool_choice.as_ref() else {
            return Ok(None);
        };
        match choice {
            RigToolChoice::Auto => Ok(Some(aws_bedrock::ToolChoice::Auto(
                aws_bedrock::AutoToolChoice::builder().build(),
            ))),
            RigToolChoice::Required => Ok(Some(aws_bedrock::ToolChoice::Any(
                aws_bedrock::AnyToolChoice::builder().build(),
            ))),
            // No Bedrock variant is mapped for "none"; the tool choice is simply omitted.
            RigToolChoice::None => Ok(None),
            // `SpecificToolChoice` names a single tool, so only the first name is used.
            RigToolChoice::Specific { function_names } => function_names
                .first()
                .map(|name| {
                    aws_bedrock::SpecificToolChoice::builder()
                        .name(name.clone())
                        .build()
                        .map(aws_bedrock::ToolChoice::Tool)
                        .map_err(|e| CompletionError::RequestError(e.into()))
                })
                .transpose(),
        }
    }

    /// Builds a JSON-schema structured-output configuration from the request's
    /// `output_schema`.
    ///
    /// The schema is named by `output_schema_name()`, falling back to `"response_schema"`.
    /// Returns `Ok(None)` when the request has no output schema.
    ///
    /// # Errors
    /// Returns [`CompletionError::RequestError`] if the schema cannot be serialized or an
    /// SDK builder rejects the output format.
    pub fn output_config(&self) -> Result<Option<aws_bedrock::OutputConfig>, CompletionError> {
        let Some(schema) = self.inner.output_schema.as_ref() else {
            return Ok(None);
        };
        let schema_name = self
            .inner
            .output_schema_name()
            .unwrap_or_else(|| "response_schema".to_string());
        let schema_json = serde_json::to_string(&schema.clone().to_value())
            .map_err(|e| CompletionError::RequestError(e.into()))?;
        let json_schema_def = aws_bedrock::JsonSchemaDefinition::builder()
            .schema(schema_json)
            .name(schema_name)
            .build()
            .map_err(|e| CompletionError::RequestError(e.into()))?;
        let text_format = aws_bedrock::OutputFormat::builder()
            .r#type(aws_bedrock::OutputFormatType::JsonSchema)
            .structure(aws_bedrock::OutputFormatStructure::JsonSchema(
                json_schema_def,
            ))
            .build()
            .map_err(|e| CompletionError::RequestError(e.into()))?;
        Ok(Some(
            aws_bedrock::OutputConfig::builder()
                .text_format(text_format)
                .build(),
        ))
    }

    /// Collects the system prompt.
    ///
    /// The result contains the preamble (if non-empty), followed by every non-empty
    /// `Message::System` in the chat history, in order. Returns `Ok(None)` when there is
    /// no system text. With prompt caching enabled, a cache point is appended after the
    /// system text.
    ///
    /// # Errors
    /// Returns [`CompletionError::RequestError`] if the cache point cannot be built.
    pub fn system_prompt(&self) -> Result<Option<Vec<SystemContentBlock>>, CompletionError> {
        let history_system_text =
            self.inner
                .chat_history
                .iter()
                .filter_map(|message| match message {
                    Message::System { content } => Some(content.clone()),
                    _ => None,
                });
        let mut system_blocks: Vec<SystemContentBlock> = self
            .inner
            .preamble
            .iter()
            .cloned()
            .chain(history_system_text)
            .filter(|text| !text.is_empty())
            .map(SystemContentBlock::Text)
            .collect();
        if system_blocks.is_empty() {
            return Ok(None);
        }
        if self.prompt_caching {
            system_blocks.push(SystemContentBlock::CachePoint(cache_point_block()?));
        }
        Ok(Some(system_blocks))
    }

    /// Converts the request's documents and chat history into Bedrock messages.
    ///
    /// - Documents, if any, are rendered with `to_string`, joined with `" | "`, and attached
    ///   as a single plain-text document in a leading user message.
    /// - `Message::System` entries are skipped because [`Self::system_prompt`] emits them.
    /// - With prompt caching enabled, a cache point is appended to the last message, unless
    ///   the chat history contains reasoning content.
    ///
    /// # Errors
    /// Returns an error if a message fails to convert or the cache point cannot be built.
    pub fn messages(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
        let mut full_history: Vec<Message> = Vec::new();
        if !self.inner.documents.is_empty() {
            let documents = self
                .inner
                .documents
                .iter()
                .map(|doc| doc.to_string())
                .collect::<Vec<_>>()
                .join(" | ");
            let content = OneOrMany::one(UserContent::document(
                documents,
                Some(DocumentMediaType::TXT),
            ));
            full_history.push(Message::User { content });
        }
        full_history.extend(
            self.inner
                .chat_history
                .iter()
                .filter(|message| !matches!(message, Message::System { .. }))
                .cloned(),
        );
        let mut messages: Vec<aws_bedrock::Message> = full_history
            .into_iter()
            .map(|message| RigMessage(message).try_into())
            .collect::<Result<Vec<aws_bedrock::Message>, _>>()?;
        let has_reasoning = self.inner.chat_history.iter().any(|message| match message {
            Message::Assistant { content, .. } => content
                .iter()
                .any(|c| matches!(c, rig_core::completion::AssistantContent::Reasoning(_))),
            _ => false,
        });
        if self.prompt_caching
            && !has_reasoning
            && let Some(last_msg) = messages.last_mut()
        {
            // `content` is a public field, so the cache point is appended in place
            // instead of cloning the content and rebuilding the message.
            last_msg
                .content
                .push(aws_bedrock::ContentBlock::CachePoint(cache_point_block()?));
        }
        Ok(messages)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for converting rig completion requests into Bedrock request components.
    use super::*;
    use rig_core::OneOrMany;
    use rig_core::completion::{CompletionRequest, ToolDefinition};
    use rig_core::message::{Message, Text, ToolChoice, UserContent};

    /// The smallest valid request: a single user text message and no other options.
    fn minimal_request() -> CompletionRequest {
        CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(user_text("test")),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Wraps `request` with the given prompt-caching flag.
    fn aws_request(request: CompletionRequest, prompt_caching: bool) -> AwsCompletionRequest {
        AwsCompletionRequest {
            inner: request,
            prompt_caching,
        }
    }

    /// A user message holding a single text block.
    fn user_text(text: &str) -> Message {
        Message::User {
            content: OneOrMany::one(UserContent::Text(Text {
                text: text.to_string(),
            })),
        }
    }

    /// An object schema without properties, as used by parameterless tools.
    fn empty_object_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {}
        })
    }

    /// A tool definition with the given name, description and JSON-schema parameters.
    fn tool_definition(
        name: &str,
        description: &str,
        parameters: serde_json::Value,
    ) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: description.to_string(),
            parameters,
        }
    }

    /// A request exposing one parameterless `test_tool` with the given tool choice.
    fn request_with_test_tool(tool_choice: Option<ToolChoice>) -> CompletionRequest {
        CompletionRequest {
            tool_choice,
            tools: vec![tool_definition(
                "test_tool",
                "A test tool",
                empty_object_schema(),
            )],
            ..minimal_request()
        }
    }

    /// Converts `request` without prompt caching and returns its tool configuration,
    /// failing the test if none was produced.
    fn build_tool_config(request: CompletionRequest) -> ToolConfiguration {
        aws_request(request, false)
            .tools_config()
            .expect("Should build tool config")
            .expect("tool config should exist when tools are defined")
    }

    #[test]
    fn test_tool_choice_auto_conversion() {
        let config = build_tool_config(request_with_test_tool(Some(ToolChoice::Auto)));
        assert!(matches!(
            config.tool_choice(),
            Some(aws_bedrock::ToolChoice::Auto(_))
        ));
    }

    #[test]
    fn test_tool_choice_required_conversion() {
        let config = build_tool_config(request_with_test_tool(Some(ToolChoice::Required)));
        assert!(matches!(
            config.tool_choice(),
            Some(aws_bedrock::ToolChoice::Any(_))
        ));
    }

    #[test]
    fn test_tool_choice_none_conversion() {
        let config = build_tool_config(request_with_test_tool(Some(ToolChoice::None)));
        assert!(config.tool_choice().is_none());
    }

    #[test]
    fn test_tool_choice_specific_conversion() {
        let request = CompletionRequest {
            tool_choice: Some(ToolChoice::Specific {
                function_names: vec!["specific_tool".to_string()],
            }),
            tools: vec![tool_definition(
                "specific_tool",
                "A specific tool",
                empty_object_schema(),
            )],
            ..minimal_request()
        };
        let config = build_tool_config(request);
        assert!(matches!(
            config.tool_choice(),
            Some(aws_bedrock::ToolChoice::Tool(specific)) if specific.name() == "specific_tool"
        ));
    }

    #[test]
    fn test_no_tool_choice_when_not_specified() {
        let config = build_tool_config(request_with_test_tool(None));
        assert!(config.tool_choice().is_none());
    }

    #[test]
    fn test_tool_with_empty_parameters() {
        let request = CompletionRequest {
            tools: vec![tool_definition(
                "document_list",
                "Lists all documents",
                empty_object_schema(),
            )],
            ..minimal_request()
        };
        let config = build_tool_config(request);
        assert_eq!(config.tools().len(), 1);
        assert!(
            matches!(&config.tools()[0], aws_bedrock::Tool::ToolSpec(spec)
                if spec.name() == "document_list"
                    && spec.description() == Some("Lists all documents")
                    && spec.input_schema().is_some()
            )
        );
    }

    #[test]
    fn test_tool_with_parameters() {
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City name"
                },
                "units": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location"]
        });
        let request = CompletionRequest {
            tools: vec![tool_definition(
                "get_weather",
                "Get weather for a location",
                parameters,
            )],
            ..minimal_request()
        };
        let config = build_tool_config(request);
        assert_eq!(config.tools().len(), 1);
        assert!(
            matches!(&config.tools()[0], aws_bedrock::Tool::ToolSpec(spec)
                if spec.name() == "get_weather"
                    && spec.description() == Some("Get weather for a location")
            )
        );
    }

    #[test]
    fn test_system_prompt_includes_system_history() {
        let request = CompletionRequest {
            chat_history: OneOrMany::many(vec![
                Message::system("History system instruction"),
                user_text("test"),
            ])
            .expect("history should be non-empty"),
            ..minimal_request()
        };
        let system_prompt = aws_request(request, false)
            .system_prompt()
            .expect("system prompt should build")
            .expect("system prompt should exist");
        assert_eq!(system_prompt.len(), 1);
        assert_eq!(
            system_prompt.first(),
            Some(&aws_bedrock::SystemContentBlock::Text(
                "History system instruction".to_string()
            ))
        );
    }

    #[test]
    fn test_system_prompt_appends_cache_point_when_prompt_caching_enabled() {
        let request = CompletionRequest {
            preamble: Some("System prompt".to_string()),
            ..minimal_request()
        };
        let system_prompt = aws_request(request, true)
            .system_prompt()
            .expect("system prompt should build")
            .expect("system prompt should exist");
        assert_eq!(system_prompt.len(), 2);
        assert_eq!(
            system_prompt.first(),
            Some(&aws_bedrock::SystemContentBlock::Text(
                "System prompt".to_string()
            ))
        );
        assert!(matches!(
            system_prompt.last(),
            Some(aws_bedrock::SystemContentBlock::CachePoint(_))
        ));
    }

    #[test]
    fn test_messages_exclude_system_history() {
        let request = CompletionRequest {
            chat_history: OneOrMany::many(vec![
                Message::system("History system instruction"),
                user_text("test"),
            ])
            .expect("history should be non-empty"),
            ..minimal_request()
        };
        let messages = aws_request(request, false)
            .messages()
            .expect("messages should convert");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, aws_bedrock::ConversationRole::User);
    }

    #[test]
    fn test_messages_append_cache_point_when_prompt_caching_enabled() {
        let messages = aws_request(minimal_request(), true)
            .messages()
            .expect("messages should convert");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, aws_bedrock::ConversationRole::User);
        assert_eq!(messages[0].content.len(), 2);
        assert!(matches!(
            messages[0].content.last(),
            Some(aws_bedrock::ContentBlock::CachePoint(_))
        ));
    }

    #[test]
    fn test_messages_skip_cache_point_when_history_contains_reasoning() {
        let reasoning =
            rig_core::message::Reasoning::new_with_signature("thinking", Some("sig".to_string()));
        let request = CompletionRequest {
            chat_history: OneOrMany::many(vec![
                user_text("user prompt"),
                Message::Assistant {
                    id: None,
                    content: OneOrMany::one(rig_core::completion::AssistantContent::Reasoning(
                        reasoning,
                    )),
                },
                user_text("follow up"),
            ])
            .expect("history should be non-empty"),
            ..minimal_request()
        };
        let request = aws_request(request, true);
        let messages = request.messages().expect("messages should convert");
        let last_message = messages.last().expect("messages should not be empty");
        assert!(
            !last_message
                .content
                .iter()
                .any(|c| matches!(c, aws_bedrock::ContentBlock::CachePoint(_))),
            "message-level cache point should be skipped when chat history contains reasoning"
        );
        // With no preamble and no system history there is no system text, so no system
        // prompt (and in particular no lone cache point) may be produced.
        let system_prompt = request.system_prompt().expect("system prompt builds");
        assert!(system_prompt.is_none());
    }

    #[test]
    fn test_output_config_none_when_no_schema() {
        let request = aws_request(minimal_request(), false);
        assert!(
            request
                .output_config()
                .expect("output config builds")
                .is_none()
        );
    }

    #[test]
    fn test_output_config_with_schema() {
        let schema: schemars::Schema = serde_json::from_value(serde_json::json!({
            "type": "object",
            "title": "WeatherResponse",
            "properties": {
                "temperature": { "type": "number" },
                "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] }
            },
            "required": ["temperature", "unit"]
        }))
        .expect("valid schema");
        let request = CompletionRequest {
            output_schema: Some(schema),
            ..minimal_request()
        };
        let config = aws_request(request, false)
            .output_config()
            .expect("output config builds")
            .expect("should have config");
        let text_format = config.text_format().expect("text_format should be set");
        assert_eq!(
            *text_format.r#type(),
            aws_bedrock::OutputFormatType::JsonSchema
        );
        let structure = text_format.structure().expect("structure should be set");
        let json_schema = structure
            .as_json_schema()
            .expect("should be JsonSchema variant");
        assert_eq!(json_schema.name(), Some("WeatherResponse"));
        let parsed: serde_json::Value =
            serde_json::from_str(json_schema.schema()).expect("schema should be valid JSON");
        assert_eq!(parsed["type"], "object");
        assert!(parsed["properties"]["temperature"].is_object());
    }

    #[test]
    fn test_output_config_uses_default_name() {
        let schema: schemars::Schema = serde_json::from_value(serde_json::json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            }
        }))
        .expect("valid schema");
        let request = CompletionRequest {
            output_schema: Some(schema),
            ..minimal_request()
        };
        let config = aws_request(request, false)
            .output_config()
            .expect("output config builds")
            .expect("should have config");
        let text_format = config.text_format().expect("text_format should be set");
        let structure = text_format.structure().expect("structure should be set");
        let json_schema = structure
            .as_json_schema()
            .expect("should be JsonSchema variant");
        assert_eq!(json_schema.name(), Some("response_schema"));
    }
}