Skip to main content

rig_bedrock/
client.rs

1use crate::image::ImageGenerationModel;
2use crate::{completion::CompletionModel, embedding::EmbeddingModel};
3use aws_config::{BehaviorVersion, Region};
4use rig::client::Nothing;
5use rig::prelude::*;
6use std::sync::Arc;
7use tokio::sync::OnceCell;
8
/// AWS region used by [`ClientBuilder`] when no region is set explicitly.
pub const DEFAULT_AWS_REGION: &str = "us-east-1";
10
/// Builder for a Bedrock [`Client`] pinned to a specific AWS region.
///
/// Starts from [`DEFAULT_AWS_REGION`] via its `Default` implementation.
#[derive(Clone, Debug)]
pub struct ClientBuilder<'a> {
    /// AWS region name handed to the SDK config loader when building.
    region: &'a str,
}
15
16impl<'a> ClientBuilder<'a> {
17    /// Make sure to verify model and region [compatibility]
18    ///
19    /// [compatibility]: https://docs.aws.amazon.com/bedrock/latest/userguide/models-regions.html
20    pub fn region(mut self, region: &'a str) -> Self {
21        self.region = region;
22        self
23    }
24
25    /// Make sure you have permissions to access [Amazon Bedrock foundation model]
26    ///
27    /// [ Amazon Bedrock foundation model]: <https://docs.aws.amazon.com/bedrock/latest/userguide/model-access-modify.html>
28    pub async fn build(self) -> Client {
29        let sdk_config = aws_config::defaults(BehaviorVersion::latest())
30            .region(Region::new(String::from(self.region)))
31            .load()
32            .await;
33        let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
34        Client {
35            profile_name: None,
36            aws_client: Arc::new(OnceCell::from(client)),
37        }
38    }
39}
40
41impl Default for ClientBuilder<'_> {
42    fn default() -> Self {
43        Self {
44            region: DEFAULT_AWS_REGION,
45        }
46    }
47}
48
49#[derive(Clone, Debug)]
50pub struct Client {
51    profile_name: Option<String>,
52    pub(crate) aws_client: Arc<OnceCell<aws_sdk_bedrockruntime::Client>>,
53}
54
55impl From<aws_sdk_bedrockruntime::Client> for Client {
56    fn from(aws_client: aws_sdk_bedrockruntime::Client) -> Self {
57        Client {
58            profile_name: None,
59            aws_client: Arc::new(OnceCell::from(aws_client)),
60        }
61    }
62}
63
64impl Client {
65    fn new() -> Self {
66        Self {
67            profile_name: None,
68            aws_client: Arc::new(OnceCell::new()),
69        }
70    }
71
72    /// Create an AWS Bedrock client using AWS profile name
73    pub fn with_profile_name(profile_name: &str) -> Self {
74        Self {
75            profile_name: Some(profile_name.into()),
76            aws_client: Arc::new(OnceCell::new()),
77        }
78    }
79
80    pub async fn get_inner(&self) -> &aws_sdk_bedrockruntime::Client {
81        self.aws_client
82            .get_or_init(|| async {
83                let config = if let Some(profile_name) = &self.profile_name {
84                    aws_config::defaults(BehaviorVersion::latest())
85                        .profile_name(profile_name)
86                        .load()
87                        .await
88                } else {
89                    aws_config::load_from_env().await
90                };
91                aws_sdk_bedrockruntime::Client::new(&config)
92            })
93            .await
94    }
95}
96
97impl ProviderClient for Client {
98    type Input = Nothing;
99    type Error = rig::client::ProviderClientError;
100
101    fn from_env() -> Result<Self, Self::Error>
102    where
103        Self: Sized,
104    {
105        Ok(Client::new())
106    }
107
108    fn from_val(_: Nothing) -> Result<Self, Self::Error>
109    where
110        Self: Sized,
111    {
112        Err(rig::client::ProviderClientError::InvalidConfiguration(
113            "use `Client::from_env()` or `Client::with_profile_name(\"aws_profile\")` instead",
114        ))
115    }
116}
117
118impl CompletionClient for Client {
119    type CompletionModel = CompletionModel;
120
121    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
122        CompletionModel::new(self.clone(), model)
123    }
124}
125
126impl EmbeddingsClient for Client {
127    type EmbeddingModel = EmbeddingModel;
128
129    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
130        EmbeddingModel::new(self.clone(), model, None)
131    }
132
133    fn embedding_model_with_ndims(
134        &self,
135        model: impl Into<String>,
136        ndims: usize,
137    ) -> Self::EmbeddingModel {
138        EmbeddingModel::new(self.clone(), model, Some(ndims))
139    }
140}
141
142impl ImageGenerationClient for Client {
143    type ImageGenerationModel = ImageGenerationModel;
144
145    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
146        ImageGenerationModel::new(self.clone(), model)
147    }
148}
149
150impl VerifyClient for Client {
151    async fn verify(&self) -> Result<(), VerifyError> {
152        // No API endpoint to verify the API key
153        Ok(())
154    }
155}