use anyhow::{anyhow, Result};
use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_sdk_bedrockruntime::{
types::{ContentBlock, ConversationRole, InferenceConfiguration, Message, SystemContentBlock},
Client,
};
use log::info;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::completions::ThinkingLevel;
use crate::constants::{AWS_BEDROCK_API_URL, AWS_REGION};
use crate::domain::RateLimit;
use crate::llm_models::{LLMModel, LLMTools};
/// Internal payload that carries the request parameters from `get_body` to `call_api`.
///
/// `get_body` serializes it into a `serde_json::Value`, and `call_api` deserializes that
/// value again before building the Bedrock Converse request.
#[derive(Serialize, Deserialize)]
struct AwsBedrockRequestBody {
/// Prompt text; `call_api` places it in the user message.
instructions: String,
/// JSON schema of the expected output; `call_api` appends it to the user message.
json_schema: Value,
/// Value handed to the inference configuration's `set_max_tokens`.
max_tokens: i32,
/// Value handed to the inference configuration's `set_temperature`.
temperature: f32,
}
/// Amazon Nova models supported by this AWS Bedrock integration.
///
/// The Bedrock model identifier of each variant is returned by `as_str`.
#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
pub enum AwsBedrockModels {
/// Model id `amazon.nova-pro-v1:0`.
NovaPro,
/// Model id `amazon.nova-lite-v1:0`.
NovaLite,
/// Model id `amazon.nova-micro-v1:0`.
NovaMicro,
}
#[async_trait(?Send)]
impl LLMModel for AwsBedrockModels {
fn as_str(&self) -> &str {
match self {
AwsBedrockModels::NovaPro => "amazon.nova-pro-v1:0",
AwsBedrockModels::NovaLite => "amazon.nova-lite-v1:0",
AwsBedrockModels::NovaMicro => "amazon.nova-micro-v1:0",
}
}
fn try_from_str(name: &str) -> Option<Self> {
match name.to_lowercase().as_str() {
"amazon.nova-pro-v1:0" => Some(AwsBedrockModels::NovaPro),
"amazon.nova-lite-v1:0" => Some(AwsBedrockModels::NovaLite),
"amazon.nova-micro-v1:0" => Some(AwsBedrockModels::NovaMicro),
_ => None,
}
}
fn default_max_tokens(&self) -> usize {
match self {
AwsBedrockModels::NovaPro => 5_120,
AwsBedrockModels::NovaLite => 5_120,
AwsBedrockModels::NovaMicro => 5_120,
}
}
/// Builds the Converse endpoint URL for this model: `<base url>/model/<model id>/converse`.
fn get_endpoint(&self) -> String {
    let base_url = &*AWS_BEDROCK_API_URL;
    let model_id = self.as_str();
    format!("{base_url}/model/{model_id}/converse")
}
/// Packs the request parameters into the JSON value that `call_api` later decodes.
///
/// The returned object has the keys `instructions`, `json_schema`, `max_tokens` and
/// `temperature`. `_function_call`, `_tools` and `_thinking_level` are not used by this
/// implementation.
///
/// `max_tokens` values larger than `i32::MAX` are clamped to `i32::MAX`, because the
/// payload stores the limit as an `i32`.
fn get_body(
    &self,
    instructions: &str,
    json_schema: &Value,
    _function_call: bool,
    max_tokens: &usize,
    temperature: &f32,
    _tools: Option<&[LLMTools]>,
    _thinking_level: Option<&ThinkingLevel>,
) -> serde_json::Value {
    // A bare `as i32` cast wraps out-of-range values (possibly to a negative number);
    // clamp first so the cast below is always lossless.
    let max_tokens = (*max_tokens).min(i32::MAX as usize) as i32;
    let body = AwsBedrockRequestBody {
        instructions: instructions.to_owned(),
        json_schema: json_schema.clone(),
        max_tokens,
        temperature: *temperature,
    };
    serde_json::to_value(body)
        .expect("AwsBedrockRequestBody only holds string, JSON and numeric fields")
}
async fn call_api(
&self,
_api_key: &str,
_version: Option<String>,
body: &serde_json::Value,
debug: bool,
_tools: Option<&[LLMTools]>,
) -> Result<String> {
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
.region(&**AWS_REGION)
.load()
.await;
if debug {
info!("[debug] AWS Bedrock SDK config: {:#?}", sdk_config);
}
let client = Client::new(&sdk_config);
let request_body_opt: Option<AwsBedrockRequestBody> =
serde_json::from_value(body.clone()).ok();
let (instructions_opt, json_schema_opt, max_tokens_opt, temperature_opt) = request_body_opt
.map_or_else(
|| (None, None, None, None),
|request_body| {
(
Some(request_body.instructions),
Some(request_body.json_schema),
Some(request_body.max_tokens),
Some(request_body.temperature),
)
},
);
let base_instructions = self.get_base_instructions(None);
let converse_builder = client
.converse()
.model_id(self.as_str())
.system(SystemContentBlock::Text(base_instructions));
let instructions = instructions_opt.unwrap_or_default();
let user_instructions = json_schema_opt
.map(|schema| {
format!(
"<instructions>
{instructions}
</instructions>
<output json schema>
{schema}
</output json schema>"
)
})
.unwrap_or(instructions);
let converse_builder = converse_builder.messages(
Message::builder()
.role(ConversationRole::User)
.content(ContentBlock::Text(user_instructions))
.build()
.map_err(|_| anyhow!("failed to build message"))?,
);
let converse_builder = if max_tokens_opt.is_some() || temperature_opt.is_some() {
let inference_config = InferenceConfiguration::builder()
.set_max_tokens(max_tokens_opt)
.set_temperature(temperature_opt)
.build();
converse_builder.set_inference_config(Some(inference_config))
} else {
converse_builder
};
let converse_response = converse_builder.send().await?;
if debug {
info!(
"[debug] AWS Bedrock API response: {:#?}",
&converse_response
);
}
let text = converse_response
.output()
.ok_or(anyhow!("no output"))?
.as_message()
.map_err(|_| anyhow!("output not a message"))?
.content()
.first()
.ok_or(anyhow!("no content in message"))?
.as_text()
.map_err(|_| anyhow!("content is not text"))?
.to_string();
Ok(self.sanitize_json_response(&text))
}
/// Returns `response_text` unchanged as an owned `String`; `_function_call` is not used.
fn get_data(&self, response_text: &str, _function_call: bool) -> Result<String> {
    let data = String::from(response_text);
    Ok(data)
}
fn get_rate_limit(&self) -> RateLimit {
match self {
AwsBedrockModels::NovaPro => RateLimit {
tpm: 400_000,
rpm: 100,
},
AwsBedrockModels::NovaLite | AwsBedrockModels::NovaMicro => RateLimit {
tpm: 2_000_000,
rpm: 1_000,
},
}
}
}