use crate::types::json::AwsDocument;
use crate::types::message::RigMessage;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use aws_sdk_bedrockruntime::types::{
InferenceConfiguration, SystemContentBlock, Tool, ToolConfiguration, ToolInputSchema,
ToolSpecification,
};
use rig::OneOrMany;
use rig::completion::{CompletionError, Message};
use rig::message::{DocumentMediaType, UserContent};
/// Newtype wrapper around a [`rig::completion::CompletionRequest`].
///
/// The methods implemented on this wrapper derive the individual pieces of an AWS Bedrock
/// request (inference configuration, tool configuration, system prompt, and messages)
/// from the wrapped rig request.
pub struct AwsCompletionRequest(pub rig::completion::CompletionRequest);
impl AwsCompletionRequest {
/// Converts the request's `additional_params` into a Smithy [`aws_smithy_types::Document`].
///
/// Returns `None` when the wrapped request carries no additional parameters. The value is
/// cloned because the conversion into [`AwsDocument`] consumes it.
pub fn additional_params(&self) -> Option<aws_smithy_types::Document> {
    let params = self.0.additional_params.clone()?;
    let document: AwsDocument = params.into();
    Some(document.0)
}
/// Builds the Bedrock [`InferenceConfiguration`] from the request's sampling settings.
///
/// Only `temperature` and `max_tokens` are forwarded; a setting that is absent on the
/// request is left unset on the configuration.
///
/// `max_tokens` is converted to `i32` with saturation: a value larger than `i32::MAX` is
/// clamped to `i32::MAX` rather than wrapping around to a negative number, which is what a
/// plain `as i32` cast would do.
///
/// This method currently always returns `Some`; the `Option` return type is kept for
/// compatibility with existing callers.
pub fn inference_config(&self) -> Option<InferenceConfiguration> {
    // Narrowing a float with `as` saturates/rounds; it never wraps.
    let temperature = self.0.temperature.map(|value| value as f32);
    // Integer `as` casts truncate silently, so convert with a checked conversion and clamp.
    let max_tokens = self
        .0
        .max_tokens
        .map(|value| i32::try_from(value).unwrap_or(i32::MAX));

    Some(
        InferenceConfiguration::builder()
            .set_temperature(temperature)
            .set_max_tokens(max_tokens)
            .build(),
    )
}
pub fn tools_config(&self) -> Result<Option<ToolConfiguration>, CompletionError> {
let mut tools = vec![];
for tool_definition in self.0.tools.iter() {
let doc: AwsDocument = tool_definition.parameters.clone().into();
let schema = ToolInputSchema::Json(doc.0);
let tool = Tool::ToolSpec(
ToolSpecification::builder()
.name(tool_definition.name.clone())
.set_description(Some(tool_definition.description.clone()))
.set_input_schema(Some(schema))
.build()
.map_err(|e| CompletionError::RequestError(e.into()))?,
);
tools.push(tool);
}
if !tools.is_empty() {
use aws_sdk_bedrockruntime::types as aws_bedrock;
let tool_choice = self.0.tool_choice.as_ref().and_then(|choice| {
match choice {
rig::message::ToolChoice::Auto => Some(aws_bedrock::ToolChoice::Auto(
aws_bedrock::AutoToolChoice::builder().build(),
)),
rig::message::ToolChoice::Required => Some(aws_bedrock::ToolChoice::Any(
aws_bedrock::AnyToolChoice::builder().build(),
)),
rig::message::ToolChoice::None => {
None
}
rig::message::ToolChoice::Specific { function_names } => {
function_names.first().map(|name| {
aws_bedrock::ToolChoice::Tool(
aws_bedrock::SpecificToolChoice::builder()
.name(name.clone())
.build()
.expect("Failed to build SpecificToolChoice"),
)
})
}
}
});
let config = ToolConfiguration::builder()
.set_tools(Some(tools))
.set_tool_choice(tool_choice)
.build()
.map_err(|e| CompletionError::RequestError(e.into()))?;
Ok(Some(config))
} else {
Ok(None)
}
}
/// Collects the system-level instructions of the request as Bedrock system content blocks.
///
/// The non-empty preamble (if any) comes first, followed by the content of every non-empty
/// `Message::System` entry of the chat history, in history order. Returns `None` when no
/// block was produced.
pub fn system_prompt(&self) -> Option<Vec<SystemContentBlock>> {
    let preamble = self
        .0
        .preamble
        .as_deref()
        .filter(|text| !text.is_empty());

    let history_instructions = self.0.chat_history.iter().filter_map(|entry| match entry {
        Message::System { content } if !content.is_empty() => Some(content.as_str()),
        _ => None,
    });

    let blocks: Vec<SystemContentBlock> = preamble
        .into_iter()
        .chain(history_instructions)
        .map(|text| SystemContentBlock::Text(text.to_owned()))
        .collect();

    (!blocks.is_empty()).then_some(blocks)
}
/// Converts the request into the list of Bedrock conversation messages.
///
/// When the request carries documents, their string renderings are joined with `" | "` and
/// sent as a single plain-text document in a leading user message. The chat history
/// follows in order, with `Message::System` entries left out.
///
/// # Errors
///
/// Returns the first error produced while converting a message into its Bedrock form.
pub fn messages(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
    let documents_message = (!self.0.documents.is_empty()).then(|| {
        let joined = self
            .0
            .documents
            .iter()
            .map(|document| document.to_string())
            .collect::<Vec<_>>()
            .join(" | ");
        Message::User {
            content: OneOrMany::one(UserContent::document(
                joined,
                Some(DocumentMediaType::TXT),
            )),
        }
    });

    let conversation = self
        .0
        .chat_history
        .iter()
        .filter(|entry| !matches!(entry, Message::System { .. }))
        .cloned();

    documents_message
        .into_iter()
        .chain(conversation)
        .map(|entry| RigMessage(entry).try_into())
        .collect::<Result<Vec<aws_bedrock::Message>, _>>()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rig::OneOrMany;
    use rig::completion::{CompletionRequest, ToolDefinition};
    use rig::message::{Message, Text, ToolChoice, UserContent};

    /// Builds a user message holding a single text part.
    fn user_text(text: &str) -> Message {
        Message::User {
            content: OneOrMany::one(UserContent::Text(Text {
                text: text.to_string(),
            })),
        }
    }

    /// Smallest valid request: one user message and every optional setting left empty.
    fn minimal_request() -> CompletionRequest {
        CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(user_text("test")),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// JSON schema of an object without any properties.
    fn empty_object_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {}
        })
    }

    /// Builds a tool definition from its three parts.
    fn tool(name: &str, description: &str, parameters: serde_json::Value) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: description.to_string(),
            parameters,
        }
    }

    /// The generic parameterless tool shared by the tool-choice tests.
    fn test_tool() -> ToolDefinition {
        tool("test_tool", "A test tool", empty_object_schema())
    }

    /// Request with a single `test_tool` and the given tool choice.
    fn request_with_tool_choice(tool_choice: Option<ToolChoice>) -> CompletionRequest {
        CompletionRequest {
            tool_choice,
            tools: vec![test_tool()],
            ..minimal_request()
        }
    }

    /// Request whose history starts with a system message followed by a user message.
    fn request_with_system_history() -> CompletionRequest {
        CompletionRequest {
            chat_history: OneOrMany::many(vec![
                Message::system("History system instruction"),
                user_text("test"),
            ])
            .expect("history should be non-empty"),
            ..minimal_request()
        }
    }

    /// Runs `tools_config` and unwraps both the `Result` and the `Option` layer.
    fn tool_config_for(request: CompletionRequest) -> ToolConfiguration {
        AwsCompletionRequest(request)
            .tools_config()
            .expect("Should build tool config")
            .expect("tool config should be present when tools are defined")
    }

    #[test]
    fn test_tool_choice_auto_conversion() {
        let config = tool_config_for(request_with_tool_choice(Some(ToolChoice::Auto)));
        assert!(matches!(
            config.tool_choice(),
            Some(aws_bedrock::ToolChoice::Auto(_))
        ));
    }

    #[test]
    fn test_tool_choice_required_conversion() {
        let config = tool_config_for(request_with_tool_choice(Some(ToolChoice::Required)));
        assert!(matches!(
            config.tool_choice(),
            Some(aws_bedrock::ToolChoice::Any(_))
        ));
    }

    #[test]
    fn test_tool_choice_none_conversion() {
        let config = tool_config_for(request_with_tool_choice(Some(ToolChoice::None)));
        assert!(config.tool_choice().is_none());
    }

    #[test]
    fn test_tool_choice_specific_conversion() {
        let config = tool_config_for(CompletionRequest {
            tool_choice: Some(ToolChoice::Specific {
                function_names: vec!["specific_tool".to_string()],
            }),
            tools: vec![tool(
                "specific_tool",
                "A specific tool",
                empty_object_schema(),
            )],
            ..minimal_request()
        });
        assert!(matches!(
            config.tool_choice(),
            Some(aws_bedrock::ToolChoice::Tool(specific)) if specific.name() == "specific_tool"
        ));
    }

    #[test]
    fn test_no_tool_choice_when_not_specified() {
        let config = tool_config_for(request_with_tool_choice(None));
        assert!(config.tool_choice().is_none());
    }

    #[test]
    fn test_tool_with_empty_parameters() {
        let config = tool_config_for(CompletionRequest {
            tools: vec![tool(
                "document_list",
                "Lists all documents",
                empty_object_schema(),
            )],
            ..minimal_request()
        });
        assert_eq!(config.tools().len(), 1);
        assert!(matches!(
            config.tools().first(),
            Some(aws_bedrock::Tool::ToolSpec(spec))
                if spec.name() == "document_list"
                    && spec.description() == Some("Lists all documents")
                    && spec.input_schema().is_some()
        ));
    }

    #[test]
    fn test_tool_with_parameters() {
        let config = tool_config_for(CompletionRequest {
            tools: vec![tool(
                "get_weather",
                "Get weather for a location",
                serde_json::json!({
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "City name"
                        },
                        "units": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"]
                        }
                    },
                    "required": ["location"]
                }),
            )],
            ..minimal_request()
        });
        assert_eq!(config.tools().len(), 1);
        assert!(matches!(
            config.tools().first(),
            Some(aws_bedrock::Tool::ToolSpec(spec))
                if spec.name() == "get_weather"
                    && spec.description() == Some("Get weather for a location")
        ));
    }

    #[test]
    fn test_system_prompt_includes_system_history() {
        let system_prompt = AwsCompletionRequest(request_with_system_history())
            .system_prompt()
            .expect("system history should produce a system prompt");
        assert_eq!(system_prompt.len(), 1);
        assert_eq!(
            system_prompt[0],
            aws_bedrock::SystemContentBlock::Text("History system instruction".to_string())
        );
    }

    #[test]
    fn test_messages_exclude_system_history() {
        let messages = AwsCompletionRequest(request_with_system_history())
            .messages()
            .expect("messages should convert");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, aws_bedrock::ConversationRole::User);
    }
}