syncable_cli/bedrock/
image.rs1use super::client::Client;
2use super::types::errors::AwsSdkInvokeModelError;
3use super::types::text_to_image::{TextToImageGeneration, TextToImageResponse};
4use aws_smithy_types::Blob;
5use rig::image_generation::{
6 self, ImageGenerationError, ImageGenerationRequest, ImageGenerationResponse,
7};
8
/// Bedrock model identifier for Amazon Titan Image Generator, version 1.
pub const AMAZON_TITAN_IMAGE_GENERATOR_V1: &str = "amazon.titan-image-generator-v1";

/// Bedrock model identifier for Amazon Titan Image Generator, version 2 (revision 0).
pub const AMAZON_TITAN_IMAGE_GENERATOR_V2_0: &str = "amazon.titan-image-generator-v2:0";

/// Bedrock model identifier for Amazon Nova Canvas, version 1 (revision 0).
pub const AMAZON_NOVA_CANVAS: &str = "amazon.nova-canvas-v1:0";

16#[derive(Clone)]
17pub struct ImageGenerationModel {
18 pub(crate) client: Client,
19 pub model: String,
20}
21
22impl ImageGenerationModel {
23 pub fn new(client: Client, model: impl Into<String>) -> Self {
24 Self {
25 client,
26 model: model.into(),
27 }
28 }
29}
30
31impl image_generation::ImageGenerationModel for ImageGenerationModel {
32 type Response = TextToImageResponse;
33
34 type Client = Client;
35
36 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
37 Self::new(client.clone(), model)
38 }
39
40 async fn image_generation(
41 &self,
42 generation_request: ImageGenerationRequest,
43 ) -> Result<ImageGenerationResponse<Self::Response>, ImageGenerationError> {
44 let mut request = TextToImageGeneration::new(generation_request.prompt);
45 request.width(generation_request.width);
46 request.height(generation_request.height);
47
48 let body = serde_json::to_string(&request)?;
49 let model_response = self
50 .client
51 .get_inner()
52 .await
53 .invoke_model()
54 .model_id(self.model.as_str())
55 .content_type("application/json")
56 .accept("application/json")
57 .body(Blob::new(body))
58 .send()
59 .await
60 .map_err(|sdk_error| {
61 Into::<ImageGenerationError>::into(AwsSdkInvokeModelError(sdk_error))
62 })?;
63
64 let response_str = String::from_utf8(model_response.body.into_inner())
65 .map_err(|e| ImageGenerationError::ResponseError(e.to_string()))?;
66
67 let result: TextToImageResponse = serde_json::from_str(&response_str)
68 .map_err(|e| ImageGenerationError::ResponseError(e.to_string()))?;
69
70 result.try_into()
71 }
72}