rig-bedrock 0.2.4

AWS Bedrock model provider for Rig integration.
Documentation
use crate::types::json::AwsDocument;
use crate::types::message::RigMessage;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use aws_sdk_bedrockruntime::types::{
    InferenceConfiguration, SystemContentBlock, Tool, ToolConfiguration, ToolInputSchema,
    ToolSpecification,
};
use rig::OneOrMany;
use rig::completion::{CompletionError, Message};
use rig::message::{ContentFormat, DocumentMediaType, UserContent};

/// Newtype wrapper around a Rig [`CompletionRequest`](rig::completion::CompletionRequest).
///
/// The wrapper exists so that this crate can attach conversions from the Rig request into
/// the AWS Bedrock runtime request pieces: inference configuration, tool configuration,
/// system prompt, and message list. The wrapped request is public and can be accessed as `.0`.
pub struct AwsCompletionRequest(pub rig::completion::CompletionRequest);

impl AwsCompletionRequest {
    pub fn additional_params(&self) -> Option<aws_smithy_types::Document> {
        self.0
            .additional_params
            .to_owned()
            .map(|params| params.into())
            .map(|doc: AwsDocument| doc.0)
    }

    pub fn inference_config(&self) -> Option<InferenceConfiguration> {
        let mut inference_configuration = InferenceConfiguration::builder();

        if let Some(temperature) = &self.0.temperature {
            inference_configuration =
                inference_configuration.set_temperature(Some(*temperature as f32));
        }

        if let Some(max_tokens) = &self.0.max_tokens {
            inference_configuration =
                inference_configuration.set_max_tokens(Some(*max_tokens as i32));
        }

        Some(inference_configuration.build())
    }

    pub fn tools_config(&self) -> Result<Option<ToolConfiguration>, CompletionError> {
        let mut tools = vec![];
        for tool_definition in self.0.tools.iter() {
            let doc: AwsDocument = tool_definition.parameters.clone().into();
            let schema = ToolInputSchema::Json(doc.0);
            let tool = Tool::ToolSpec(
                ToolSpecification::builder()
                    .name(tool_definition.name.clone())
                    .set_description(Some(tool_definition.description.clone()))
                    .set_input_schema(Some(schema))
                    .build()
                    .map_err(|e| CompletionError::RequestError(e.into()))?,
            );
            tools.push(tool);
        }

        if !tools.is_empty() {
            let config = ToolConfiguration::builder()
                .set_tools(Some(tools))
                .build()
                .map_err(|e| CompletionError::RequestError(e.into()))?;

            Ok(Some(config))
        } else {
            Ok(None)
        }
    }

    pub fn system_prompt(&self) -> Option<Vec<SystemContentBlock>> {
        self.0
            .preamble
            .to_owned()
            .map(|system_prompt| vec![SystemContentBlock::Text(system_prompt)])
    }

    pub fn messages(&self) -> Result<Vec<aws_bedrock::Message>, CompletionError> {
        let mut full_history: Vec<Message> = Vec::new();

        if !self.0.documents.is_empty() {
            let messages = self
                .0
                .documents
                .iter()
                .map(|doc| doc.to_string())
                .collect::<Vec<_>>()
                .join(" | ");

            let content = OneOrMany::one(UserContent::document(
                messages,
                Some(ContentFormat::String),
                Some(DocumentMediaType::TXT),
            ));

            full_history.push(Message::User { content });
        }

        self.0.chat_history.iter().for_each(|message| {
            full_history.push(message.clone());
        });

        full_history
            .into_iter()
            .map(|message| RigMessage(message).try_into())
            .collect::<Result<Vec<aws_bedrock::Message>, _>>()
    }
}