1use crate::image::ImageGenerationModel;
2use crate::{completion::CompletionModel, embedding::EmbeddingModel};
3use aws_config::{BehaviorVersion, Region};
4use rig::client::Nothing;
5use rig::prelude::*;
6use std::sync::Arc;
7use tokio::sync::OnceCell;
8
/// AWS region that [`ClientBuilder::default`] starts with when no region is
/// supplied through [`ClientBuilder::region`].
pub const DEFAULT_AWS_REGION: &str = "us-east-1";
10
/// Builder for a [`Client`] that targets an explicit AWS region.
///
/// Start from [`Default::default`], optionally override the region, then call
/// `build` to obtain the client.
#[derive(Clone, Debug)]
pub struct ClientBuilder<'a> {
    /// Region name handed to the AWS config loader when the client is built.
    region: &'a str,
}
15
16impl<'a> ClientBuilder<'a> {
17 pub fn region(mut self, region: &'a str) -> Self {
21 self.region = region;
22 self
23 }
24
25 pub async fn build(self) -> Client {
29 let sdk_config = aws_config::defaults(BehaviorVersion::latest())
30 .region(Region::new(String::from(self.region)))
31 .load()
32 .await;
33 let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
34 Client {
35 profile_name: None,
36 aws_client: Arc::new(OnceCell::from(client)),
37 }
38 }
39}
40
41impl Default for ClientBuilder<'_> {
42 fn default() -> Self {
43 Self {
44 region: DEFAULT_AWS_REGION,
45 }
46 }
47}
48
/// Handle around the AWS Bedrock runtime SDK client.
///
/// The SDK client lives in a shared cell behind an [`Arc`], so clones of this
/// value refer to the same cell.
#[derive(Clone, Debug)]
pub struct Client {
    /// Optional AWS profile name; when set, `get_inner` passes it to the
    /// config loader on first initialization.
    profile_name: Option<String>,
    /// Cell holding the SDK client. It is either pre-filled at construction
    /// (`From`, `ClientBuilder::build`) or created empty and filled by
    /// `get_inner`.
    pub(crate) aws_client: Arc<OnceCell<aws_sdk_bedrockruntime::Client>>,
}
54
55impl From<aws_sdk_bedrockruntime::Client> for Client {
56 fn from(aws_client: aws_sdk_bedrockruntime::Client) -> Self {
57 Client {
58 profile_name: None,
59 aws_client: Arc::new(OnceCell::from(aws_client)),
60 }
61 }
62}
63
64impl Client {
65 fn new() -> Self {
66 Self {
67 profile_name: None,
68 aws_client: Arc::new(OnceCell::new()),
69 }
70 }
71
72 pub fn with_profile_name(profile_name: &str) -> Self {
74 Self {
75 profile_name: Some(profile_name.into()),
76 aws_client: Arc::new(OnceCell::new()),
77 }
78 }
79
80 pub async fn get_inner(&self) -> &aws_sdk_bedrockruntime::Client {
81 self.aws_client
82 .get_or_init(|| async {
83 let config = if let Some(profile_name) = &self.profile_name {
84 aws_config::defaults(BehaviorVersion::latest())
85 .profile_name(profile_name)
86 .load()
87 .await
88 } else {
89 aws_config::load_from_env().await
90 };
91 aws_sdk_bedrockruntime::Client::new(&config)
92 })
93 .await
94 }
95}
96
97impl ProviderClient for Client {
98 type Input = Nothing;
99 type Error = rig::client::ProviderClientError;
100
101 fn from_env() -> Result<Self, Self::Error>
102 where
103 Self: Sized,
104 {
105 Ok(Client::new())
106 }
107
108 fn from_val(_: Nothing) -> Result<Self, Self::Error>
109 where
110 Self: Sized,
111 {
112 Err(rig::client::ProviderClientError::InvalidConfiguration(
113 "use `Client::from_env()` or `Client::with_profile_name(\"aws_profile\")` instead",
114 ))
115 }
116}
117
118impl CompletionClient for Client {
119 type CompletionModel = CompletionModel;
120
121 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
122 CompletionModel::new(self.clone(), model)
123 }
124}
125
126impl EmbeddingsClient for Client {
127 type EmbeddingModel = EmbeddingModel;
128
129 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
130 EmbeddingModel::new(self.clone(), model, None)
131 }
132
133 fn embedding_model_with_ndims(
134 &self,
135 model: impl Into<String>,
136 ndims: usize,
137 ) -> Self::EmbeddingModel {
138 EmbeddingModel::new(self.clone(), model, Some(ndims))
139 }
140}
141
142impl ImageGenerationClient for Client {
143 type ImageGenerationModel = ImageGenerationModel;
144
145 fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
146 ImageGenerationModel::new(self.clone(), model)
147 }
148}
149
150impl VerifyClient for Client {
151 async fn verify(&self) -> Result<(), VerifyError> {
152 Ok(())
154 }
155}