// langdb_core/executor/image_generation.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use crate::handler::CallbackHandlerFn;
5use crate::handler::ModelEventWithDetails;
6use crate::llm_gateway::provider::Provider;
7use crate::model::image_generation::initialize_image_generation;
8use crate::model::types::ModelEvent;
9use crate::model::CredentialsIdent;
10use crate::models::ModelMetadata;
11use crate::types::engine::ImageGenerationModelDefinition;
12use crate::types::gateway::CreateImageRequest;
13use crate::types::image::ImagesResponse;
14use crate::types::provider::InferenceModelProvider;
15use crate::GatewayError;
16use crate::{
17    model::types::ModelEventType,
18    types::{
19        credentials::Credentials,
20        engine::{Model, ModelType},
21        gateway::CostCalculator,
22    },
23};
24use actix_web::HttpRequest;
25use tracing::Span;
26use tracing_futures::Instrument;
27
28use super::get_key_credentials;
29use super::ProvidersConfig;
30
31pub async fn handle_image_generation(
32    mut request: CreateImageRequest,
33    callback_handler: &CallbackHandlerFn,
34    llm_model: &ModelMetadata,
35    key_credentials: Option<&Credentials>,
36    cost_calculator: Arc<Box<dyn CostCalculator>>,
37    tags: HashMap<String, String>,
38    req: HttpRequest,
39) -> Result<ImagesResponse, GatewayError> {
40    let span = Span::current();
41    request.model = llm_model.inference_provider.model_name.clone();
42
43    let providers_config = req.app_data::<ProvidersConfig>().cloned();
44    let key = get_key_credentials(
45        key_credentials,
46        providers_config.as_ref(),
47        &llm_model.inference_provider.provider.to_string(),
48    );
49    let engine = Provider::get_image_engine_for_model(llm_model, &request, key.as_ref())?;
50
51    let api_provider_name = match &llm_model.inference_provider.provider {
52        InferenceModelProvider::Proxy(provider) => provider.clone(),
53        _ => engine.provider_name().to_string(),
54    };
55
56    let db_model = Model {
57        name: llm_model.model.clone(),
58        inference_model_name: llm_model.inference_provider.model_name.clone(),
59        provider_name: api_provider_name.clone(),
60        model_type: ModelType::ImageGeneration,
61        price: llm_model.price.clone(),
62        credentials_ident: match key_credentials {
63            Some(_) => CredentialsIdent::Own,
64            _ => CredentialsIdent::Langdb,
65        },
66    };
67
68    let image_model_definition = ImageGenerationModelDefinition {
69        name: llm_model.model.clone(),
70        engine,
71        db_model: db_model.clone(),
72    };
73
74    let cost_calculator = cost_calculator.clone();
75    let callback_handler = callback_handler.clone();
76    let (tx, mut rx) = tokio::sync::mpsc::channel::<Option<ModelEvent>>(1000);
77
78    let handle = tokio::spawn(async move {
79        let mut stop_event = None;
80        while let Some(Some(msg)) = rx.recv().await {
81            if let ModelEvent {
82                event: ModelEventType::ImageGenerationFinish(e),
83                ..
84            } = &msg
85            {
86                stop_event = Some(e.clone());
87            }
88
89            callback_handler.on_message(ModelEventWithDetails::new(msg, Some(db_model.clone())));
90        }
91
92        stop_event
93    });
94
95    let model = initialize_image_generation(
96        image_model_definition.clone(),
97        Some(cost_calculator.clone()),
98        llm_model.inference_provider.endpoint.as_deref(),
99        Some(llm_model.model_provider.as_str()),
100    )
101    .await
102    .map_err(|e| GatewayError::CustomError(e.to_string()))?;
103
104    let result = model
105        .create_new(&request, tx, tags.clone())
106        .instrument(span.clone())
107        .await?;
108
109    let _stop_event = handle.await.unwrap();
110
111    Ok(result)
112}