rig_bedrock/
completion.rs1use crate::{
4 client::Client,
5 types::{
6 assistant_content::AwsConverseOutput, completion_request::AwsCompletionRequest,
7 converse_output::InternalConverseOutput, errors::AwsSdkConverseError,
8 },
9};
10
11use rig::completion::{self, CompletionError, CompletionRequest};
12use rig::streaming::StreamingCompletionResponse;
13use rig::telemetry::SpanCombinator;
14use tracing::Instrument;
15
// Model identifiers accepted by this provider. Each constant holds the exact ID string that is
// handed to the Converse request builder as `model_id`; the doc line repeats the literal so it
// shows up in rustdoc and IDE hovers.

/// Model ID `ai21.jamba-1-5-large-v1:0`.
pub const AI21_JAMBA_1_5_LARGE: &str = "ai21.jamba-1-5-large-v1:0";
/// Model ID `ai21.jamba-1-5-mini-v1:0`.
pub const AI21_JAMBA_1_5_MINI: &str = "ai21.jamba-1-5-mini-v1:0";
/// Model ID `amazon.nova-canvas-v1:0`.
pub const AMAZON_NOVA_CANVAS: &str = "amazon.nova-canvas-v1:0";
/// Model ID `amazon.nova-lite-v1:0`.
pub const AMAZON_NOVA_LITE: &str = "amazon.nova-lite-v1:0";
/// Model ID `amazon.nova-micro-v1:0`.
pub const AMAZON_NOVA_MICRO: &str = "amazon.nova-micro-v1:0";
/// Model ID `amazon.nova-premier-v1:0`.
pub const AMAZON_NOVA_PREMIER: &str = "amazon.nova-premier-v1:0";
/// Model ID `amazon.nova-pro-v1:0`.
pub const AMAZON_NOVA_PRO: &str = "amazon.nova-pro-v1:0";
/// Model ID `amazon.nova-reel-v1:0`.
pub const AMAZON_NOVA_REEL_V1_0: &str = "amazon.nova-reel-v1:0";
/// Model ID `amazon.nova-reel-v1:1`.
pub const AMAZON_NOVA_REEL_V1_1: &str = "amazon.nova-reel-v1:1";
/// Model ID `amazon.nova-sonic-v1:0`.
pub const AMAZON_NOVA_SONIC: &str = "amazon.nova-sonic-v1:0";
/// Model ID `amazon.rerank-v1:0`.
pub const AMAZON_RERANK_1_0: &str = "amazon.rerank-v1:0";
/// Model ID `amazon.titan-embed-text-v1`.
pub const AMAZON_TITAN_EMBEDDINGS_G1_TEXT: &str = "amazon.titan-embed-text-v1";
/// Model ID `amazon.titan-image-generator-v2:0`.
pub const AMAZON_TITAN_IMAGE_GENERATOR_G1_V2: &str = "amazon.titan-image-generator-v2:0";
/// Model ID `amazon.titan-image-generator-v1`.
pub const AMAZON_TITAN_IMAGE_GENERATOR_G1: &str = "amazon.titan-image-generator-v1";
/// Model ID `amazon.titan-embed-image-v1`.
pub const AMAZON_TITAN_MULTIMODAL_EMBEDDINGS_G1: &str = "amazon.titan-embed-image-v1";
/// Model ID `amazon.titan-embed-text-v2:0`.
pub const AMAZON_TITAN_TEXT_EMBEDDINGS_V2: &str = "amazon.titan-embed-text-v2:0";
/// Model ID `amazon.titan-text-express-v1`.
pub const AMAZON_TITAN_TEXT_EXPRESS_V1: &str = "amazon.titan-text-express-v1";
/// Model ID `amazon.titan-text-lite-v1`.
pub const AMAZON_TITAN_TEXT_LITE_V1: &str = "amazon.titan-text-lite-v1";
/// Model ID `amazon.titan-text-premier-v1:0`.
pub const AMAZON_TITAN_TEXT_PREMIER_V1_0: &str = "amazon.titan-text-premier-v1:0";
/// Model ID `anthropic.claude-3-haiku-20240307-v1:0`.
pub const ANTHROPIC_CLAUDE_3_HAIKU: &str = "anthropic.claude-3-haiku-20240307-v1:0";
/// Model ID `anthropic.claude-3-opus-20240229-v1:0`.
pub const ANTHROPIC_CLAUDE_3_OPUS: &str = "anthropic.claude-3-opus-20240229-v1:0";
/// Model ID `anthropic.claude-3-sonnet-20240229-v1:0`.
pub const ANTHROPIC_CLAUDE_3_SONNET: &str = "anthropic.claude-3-sonnet-20240229-v1:0";
/// Model ID `anthropic.claude-3-5-haiku-20241022-v1:0`.
pub const ANTHROPIC_CLAUDE_3_5_HAIKU: &str = "anthropic.claude-3-5-haiku-20241022-v1:0";
/// Model ID `anthropic.claude-3-5-sonnet-20241022-v2:0`.
pub const ANTHROPIC_CLAUDE_3_5_SONNET_V2: &str = "anthropic.claude-3-5-sonnet-20241022-v2:0";
/// Model ID `anthropic.claude-3-5-sonnet-20240620-v1:0`.
pub const ANTHROPIC_CLAUDE_3_5_SONNET: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0";
/// Model ID `anthropic.claude-3-7-sonnet-20250219-v1:0`.
pub const ANTHROPIC_CLAUDE_3_7_SONNET: &str = "anthropic.claude-3-7-sonnet-20250219-v1:0";
/// Model ID `anthropic.claude-opus-4-20250514-v1:0`.
pub const ANTHROPIC_CLAUDE_OPUS_4: &str = "anthropic.claude-opus-4-20250514-v1:0";
/// Model ID `anthropic.claude-sonnet-4-20250514-v1:0`.
pub const ANTHROPIC_CLAUDE_SONNET_4: &str = "anthropic.claude-sonnet-4-20250514-v1:0";
/// Model ID `cohere.command-light-text-v14`.
pub const COHERE_COMMAND_LIGHT_TEXT: &str = "cohere.command-light-text-v14";
/// Model ID `cohere.command-r-plus-v1:0`.
pub const COHERE_COMMAND_R_PLUS: &str = "cohere.command-r-plus-v1:0";
/// Model ID `cohere.command-r-v1:0`.
pub const COHERE_COMMAND_R: &str = "cohere.command-r-v1:0";
/// Model ID `cohere.command-text-v14`.
pub const COHERE_COMMAND: &str = "cohere.command-text-v14";
/// Model ID `cohere.embed-english-v3`.
pub const COHERE_EMBED_ENGLISH: &str = "cohere.embed-english-v3";
/// Model ID `cohere.embed-multilingual-v3`.
pub const COHERE_EMBED_MULTILINGUAL: &str = "cohere.embed-multilingual-v3";
/// Model ID `cohere.rerank-v3-5:0`.
pub const COHERE_RERANK_V3_5: &str = "cohere.rerank-v3-5:0";
/// Model ID `deepseek.r1-v1:0`.
pub const DEEPSEEK_R1: &str = "deepseek.r1-v1:0";
/// Model ID `luma.ray-v2:0`.
pub const LUMA_RAY_V2_0: &str = "luma.ray-v2:0";
/// Model ID `meta.llama3-8b-instruct-v1:0`.
pub const LLAMA_3_8B_INSTRUCT: &str = "meta.llama3-8b-instruct-v1:0";
/// Model ID `meta.llama3-70b-instruct-v1:0`.
pub const LLAMA_3_70B_INSTRUCT: &str = "meta.llama3-70b-instruct-v1:0";
/// Model ID `meta.llama3-1-8b-instruct-v1:0`.
pub const LLAMA_3_1_8B_INSTRUCT: &str = "meta.llama3-1-8b-instruct-v1:0";
/// Model ID `meta.llama3-1-70b-instruct-v1:0`.
pub const LLAMA_3_1_70B_INSTRUCT: &str = "meta.llama3-1-70b-instruct-v1:0";
/// Model ID `meta.llama3-1-405b-instruct-v1:0`.
pub const LLAMA_3_1_405B_INSTRUCT: &str = "meta.llama3-1-405b-instruct-v1:0";
/// Model ID `meta.llama3-2-1b-instruct-v1:0`.
pub const LLAMA_3_2_1B_INSTRUCT: &str = "meta.llama3-2-1b-instruct-v1:0";
/// Model ID `meta.llama3-2-3b-instruct-v1:0`.
pub const LLAMA_3_2_3B_INSTRUCT: &str = "meta.llama3-2-3b-instruct-v1:0";
/// Model ID `meta.llama3-2-11b-instruct-v1:0`.
pub const LLAMA_3_2_11B_INSTRUCT: &str = "meta.llama3-2-11b-instruct-v1:0";
/// Model ID `meta.llama3-2-90b-instruct-v1:0`.
pub const LLAMA_3_2_90B_INSTRUCT: &str = "meta.llama3-2-90b-instruct-v1:0";
/// Model ID `meta.llama3-3-70b-instruct-v1:0`.
pub const META_LLAMA_3_3_70B_INSTRUCT: &str = "meta.llama3-3-70b-instruct-v1:0";
/// Model ID `meta.llama4-maverick-17b-instruct-v1:0`.
pub const META_LLAMA_4_MAVERICK_17B_INSTRUCT: &str = "meta.llama4-maverick-17b-instruct-v1:0";
/// Model ID `meta.llama4-scout-17b-instruct-v1:0`.
pub const META_LLAMA_4_SCOUT_17B_INSTRUCT: &str = "meta.llama4-scout-17b-instruct-v1:0";
/// Model ID `mistral.mistral-7b-instruct-v0:2`.
pub const MISTRAL_7B_INSTRUCT: &str = "mistral.mistral-7b-instruct-v0:2";
/// Model ID `mistral.mistral-large-2402-v1:0`.
pub const MISTRAL_LARGE_24_02: &str = "mistral.mistral-large-2402-v1:0";
/// Model ID `mistral.mistral-large-2407-v1:0`.
pub const MISTRAL_LARGE_24_07: &str = "mistral.mistral-large-2407-v1:0";
/// Model ID `mistral.mistral-small-2402-v1:0`.
pub const MISTRAL_SMALL_24_02: &str = "mistral.mistral-small-2402-v1:0";
/// Model ID `mistral.mixtral-8x7b-instruct-v0:1`.
pub const MISTRAL_MIXTRAL_8X7B_INSTRUCT_V0: &str = "mistral.mixtral-8x7b-instruct-v0:1";
/// Model ID `mistral.pixtral-large-2502-v1:0`.
pub const MISTRAL_PIXTRAL_LARGE_2502: &str = "mistral.pixtral-large-2502-v1:0";
/// Model ID `stability.sd3-5-large-v1:0`.
pub const STABILITY_SD3_5_LARGE: &str = "stability.sd3-5-large-v1:0";
/// Model ID `stability.stable-image-core-v1:1`.
pub const STABILITY_STABLE_IMAGE_CORE_1_0: &str = "stability.stable-image-core-v1:1";
/// Model ID `stability.stable-image-ultra-v1:1`.
pub const STABILITY_STABLE_IMAGE_ULTRA_1_0: &str = "stability.stable-image-ultra-v1:1";
/// Model ID `twelvelabs.marengo-embed-2-7-v1:0`.
pub const TWELVELABS_MARENGO_EMBED_V2_7: &str = "twelvelabs.marengo-embed-2-7-v1:0";
/// Model ID `twelvelabs.pegasus-1-2-v1:0`.
pub const TWELVELABS_PEGASUS_V1_2: &str = "twelvelabs.pegasus-1-2-v1:0";
/// Model ID `writer.palmyra-x4-v1:0`.
pub const WRITER_PALMYRA_X4: &str = "writer.palmyra-x4-v1:0";
/// Model ID `writer.palmyra-x5-v1:0`.
pub const WRITER_PALMYRA_X5: &str = "writer.palmyra-x5-v1:0";
/// Model ID `ai21.jamba-instruct-v1:0`.
pub const AI21_JAMBA_INSTRUCT: &str = "ai21.jamba-instruct-v1:0";
/// Model ID `anthropic.claude-v2:1`.
pub const ANTHROPIC_CLAUDE_2_1: &str = "anthropic.claude-v2:1";
/// Model ID `anthropic.claude-v2`.
pub const ANTHROPIC_CLAUDE_2: &str = "anthropic.claude-v2";
/// Model ID `anthropic.claude-instant-v1`.
pub const ANTHROPIC_CLAUDE_INSTANT: &str = "anthropic.claude-instant-v1";
/// Model ID `anthropic.claude-instant-v1:2`.
pub const ANTHROPIC_CLAUDE_INSTANT_V1_2: &str = "anthropic.claude-instant-v1:2";
/// Model ID `anthropic.claude-v2:0`.
pub const ANTHROPIC_CLAUDE: &str = "anthropic.claude-v2:0";
/// Model ID `stability.sd3-large-v1:0`.
pub const STABILITY_SD3_LARGE_1_0: &str = "stability.sd3-large-v1:0";
/// Model ID `stability.stable-diffusion-xl-v1`.
pub const STABILITY_SDXL_1_0: &str = "stability.stable-diffusion-xl-v1";
/// Model ID `stability.stable-image-core-v1:0`.
pub const STABILITY_STABLE_IMAGE_CORE_1_0_V1_0: &str = "stability.stable-image-core-v1:0";
/// Model ID `stability.stable-image-ultra-v1:0`.
pub const STABILITY_STABLE_IMAGE_ULTRA_1_0_V1_0: &str = "stability.stable-image-ultra-v1:0";
160
161#[derive(Clone)]
162pub struct CompletionModel {
163 pub(crate) client: Client,
164 pub model: String,
165 pub prompt_caching: bool,
170}
171
172impl CompletionModel {
173 pub fn new(client: Client, model: impl Into<String>) -> Self {
174 Self {
175 client,
176 model: model.into(),
177 prompt_caching: false,
178 }
179 }
180
181 pub fn with_prompt_caching(mut self) -> Self {
197 self.prompt_caching = true;
198 self
199 }
200}
201
202pub(crate) fn resolve_request_model(
203 default_model: &str,
204 completion_request: &CompletionRequest,
205) -> String {
206 completion_request
207 .model
208 .clone()
209 .unwrap_or_else(|| default_model.to_string())
210}
211
212impl completion::CompletionModel for CompletionModel {
213 type Response = AwsConverseOutput;
214 type StreamingResponse = crate::streaming::BedrockStreamingResponse;
215
216 type Client = Client;
217
218 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
219 Self::new(client.clone(), model)
220 }
221
222 async fn completion(
223 &self,
224 completion_request: completion::CompletionRequest,
225 ) -> Result<completion::CompletionResponse<AwsConverseOutput>, CompletionError> {
226 let request_model = resolve_request_model(&self.model, &completion_request);
227
228 let span = if tracing::Span::current().is_disabled() {
229 tracing::info_span!(
230 target: "rig::completions",
231 "chat",
232 gen_ai.operation.name = "chat",
233 gen_ai.provider.name = "aws_bedrock",
234 gen_ai.request.model = &request_model,
235 gen_ai.system_instructions = &completion_request.preamble,
236 gen_ai.response.id = tracing::field::Empty,
237 gen_ai.response.model = tracing::field::Empty,
238 gen_ai.usage.output_tokens = tracing::field::Empty,
239 gen_ai.usage.input_tokens = tracing::field::Empty,
240 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
241 gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
242 )
243 } else {
244 tracing::Span::current()
245 };
246
247 let request = AwsCompletionRequest {
248 inner: completion_request,
249 prompt_caching: self.prompt_caching,
250 };
251
252 let mut converse_builder = self
253 .client
254 .get_inner()
255 .await
256 .converse()
257 .model_id(request_model.clone());
258
259 let tool_config = request.tools_config()?;
260 let messages = request.messages()?;
261 converse_builder = converse_builder
262 .set_additional_model_request_fields(request.additional_params())
263 .set_inference_config(request.inference_config())
264 .set_tool_config(tool_config)
265 .set_system(request.system_prompt())
266 .set_messages(Some(messages));
267
268 async move {
269 let response = converse_builder.send().await.map_err(|sdk_error| {
270 Into::<CompletionError>::into(AwsSdkConverseError(sdk_error))
271 })?;
272
273 let response: InternalConverseOutput = response.try_into().map_err(|x| {
274 CompletionError::ProviderError(format!("Type conversion error: {x}"))
275 })?;
276
277 let aws_output = AwsConverseOutput(response);
278
279 let span = tracing::Span::current();
280 span.record_response_metadata(&aws_output);
281 span.record_token_usage(&aws_output);
282
283 aws_output.try_into()
284 }
285 .instrument(span)
286 .await
287 }
288
289 async fn stream(
290 &self,
291 request: CompletionRequest,
292 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
293 CompletionModel::stream(self, request).await
294 }
295}