// langdb_core/executor/image_generation.rs
use std::collections::HashMap;
use std::sync::Arc;

use crate::handler::CallbackHandlerFn;
use crate::handler::ModelEventWithDetails;
use crate::llm_gateway::provider::Provider;
use crate::model::image_generation::initialize_image_generation;
use crate::model::types::ModelEvent;
use crate::model::CredentialsIdent;
use crate::models::ModelMetadata;
use crate::types::engine::ImageGenerationModelDefinition;
use crate::types::gateway::CreateImageRequest;
use crate::types::image::ImagesResponse;
use crate::types::provider::InferenceModelProvider;
use crate::GatewayError;
use crate::{
    model::types::ModelEventType,
    types::{
        credentials::Credentials,
        engine::{Model, ModelType},
        gateway::CostCalculator,
    },
};
use actix_web::HttpRequest;
use tracing::Span;
use tracing_futures::Instrument;

use super::get_key_credentials;
use super::ProvidersConfig;
31pub async fn handle_image_generation(
32 mut request: CreateImageRequest,
33 callback_handler: &CallbackHandlerFn,
34 llm_model: &ModelMetadata,
35 key_credentials: Option<&Credentials>,
36 cost_calculator: Arc<Box<dyn CostCalculator>>,
37 tags: HashMap<String, String>,
38 req: HttpRequest,
39) -> Result<ImagesResponse, GatewayError> {
40 let span = Span::current();
41 request.model = llm_model.inference_provider.model_name.clone();
42
43 let providers_config = req.app_data::<ProvidersConfig>().cloned();
44 let key = get_key_credentials(
45 key_credentials,
46 providers_config.as_ref(),
47 &llm_model.inference_provider.provider.to_string(),
48 );
49 let engine = Provider::get_image_engine_for_model(llm_model, &request, key.as_ref())?;
50
51 let api_provider_name = match &llm_model.inference_provider.provider {
52 InferenceModelProvider::Proxy(provider) => provider.clone(),
53 _ => engine.provider_name().to_string(),
54 };
55
56 let db_model = Model {
57 name: llm_model.model.clone(),
58 inference_model_name: llm_model.inference_provider.model_name.clone(),
59 provider_name: api_provider_name.clone(),
60 model_type: ModelType::ImageGeneration,
61 price: llm_model.price.clone(),
62 credentials_ident: match key_credentials {
63 Some(_) => CredentialsIdent::Own,
64 _ => CredentialsIdent::Langdb,
65 },
66 };
67
68 let image_model_definition = ImageGenerationModelDefinition {
69 name: llm_model.model.clone(),
70 engine,
71 db_model: db_model.clone(),
72 };
73
74 let cost_calculator = cost_calculator.clone();
75 let callback_handler = callback_handler.clone();
76 let (tx, mut rx) = tokio::sync::mpsc::channel::<Option<ModelEvent>>(1000);
77
78 let handle = tokio::spawn(async move {
79 let mut stop_event = None;
80 while let Some(Some(msg)) = rx.recv().await {
81 if let ModelEvent {
82 event: ModelEventType::ImageGenerationFinish(e),
83 ..
84 } = &msg
85 {
86 stop_event = Some(e.clone());
87 }
88
89 callback_handler.on_message(ModelEventWithDetails::new(msg, Some(db_model.clone())));
90 }
91
92 stop_event
93 });
94
95 let model = initialize_image_generation(
96 image_model_definition.clone(),
97 Some(cost_calculator.clone()),
98 llm_model.inference_provider.endpoint.as_deref(),
99 Some(llm_model.model_provider.as_str()),
100 )
101 .await
102 .map_err(|e| GatewayError::CustomError(e.to_string()))?;
103
104 let result = model
105 .create_new(&request, tx, tags.clone())
106 .instrument(span.clone())
107 .await?;
108
109 let _stop_event = handle.await.unwrap();
110
111 Ok(result)
112}