firebase_rs_sdk/ai/models/generative_model.rs

1use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::ai::api::AiService;
6use crate::ai::backend::BackendType;
7use crate::ai::error::AiResult;
8use crate::ai::requests::{PreparedRequest, RequestOptions, Task};
9
10/// Port of the Firebase JS SDK `GenerativeModel` class.
11///
12/// Reference: `packages/ai/src/models/generative-model.ts`.
13#[derive(Clone, Debug)]
14pub struct GenerativeModel {
15    service: Arc<AiService>,
16    model: String,
17    default_request_options: Option<RequestOptions>,
18}
19
20impl GenerativeModel {
21    /// Creates a new generative model bound to the provided `AiService`.
22    ///
23    /// This mirrors the TypeScript constructor, normalising the model name according to the selected
24    /// backend and capturing the service API settings for later requests.
25    pub fn new(
26        service: Arc<AiService>,
27        model_name: impl Into<String>,
28        request_options: Option<RequestOptions>,
29    ) -> AiResult<Self> {
30        let backend_type = service.backend_type();
31        let model = normalize_model_name(model_name.into(), backend_type);
32        Ok(Self {
33            service,
34            model,
35            default_request_options: request_options,
36        })
37    }
38
39    /// Returns the fully qualified model resource identifier.
40    pub fn model(&self) -> &str {
41        &self.model
42    }
43
44    /// Prepares a `generateContent` request using the stored API settings.
45    pub async fn prepare_generate_content_request(
46        &self,
47        body: Value,
48        request_options: Option<RequestOptions>,
49    ) -> AiResult<PreparedRequest> {
50        let factory = self.service.request_factory().await?;
51        let effective_options = request_options.or_else(|| self.default_request_options.clone());
52        factory.construct_request(
53            &self.model,
54            Task::GenerateContent,
55            false,
56            body,
57            effective_options,
58        )
59    }
60}
61
62fn normalize_model_name(model: String, backend_type: BackendType) -> String {
63    match backend_type {
64        BackendType::GoogleAi => normalize_google_ai_model(model),
65        BackendType::VertexAi => normalize_vertex_ai_model(model),
66    }
67}
68
/// Ensures a Google AI model name carries the `models/` prefix.
///
/// `"gemini-pro"` becomes `"models/gemini-pro"`; a name that already starts with
/// `models/` is returned unchanged.
fn normalize_google_ai_model(model: String) -> String {
    const PREFIX: &str = "models/";
    if !model.starts_with(PREFIX) {
        return format!("{PREFIX}{model}");
    }
    model
}
76
/// Converts a Vertex AI model name into a `publishers/...`-style resource path.
///
/// - A bare name (no `/`) becomes `publishers/google/models/<name>`.
/// - A name starting with `models/` gets `publishers/google/` prepended.
/// - Any other path (e.g. a tuned-model resource) is returned unchanged.
fn normalize_vertex_ai_model(model: String) -> String {
    if !model.contains('/') {
        return format!("publishers/google/models/{model}");
    }
    if model.starts_with("models/") {
        return format!("publishers/google/{model}");
    }
    model
}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ai::backend::Backend;
    use crate::ai::public_types::AiOptions;
    use crate::app::initialize_app;
    use crate::app::{FirebaseAppSettings, FirebaseOptions};
    use serde_json::json;
    use std::time::Duration;

    /// Returns app settings with a counter-based unique name, so apps created by
    /// different tests never share a name.
    fn unique_settings() -> FirebaseAppSettings {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        FirebaseAppSettings {
            name: Some(format!(
                "gen-model-{}",
                COUNTER.fetch_add(1, Ordering::SeqCst)
            )),
            ..Default::default()
        }
    }

    /// Firebase options with an API key, project id and app id populated.
    fn test_options() -> FirebaseOptions {
        FirebaseOptions {
            api_key: Some("api".into()),
            project_id: Some("project".into()),
            app_id: Some("app".into()),
            ..Default::default()
        }
    }

    /// Initialises a fresh app and returns its AI service: configured with `backend` when
    /// given, otherwise the default service from `get_ai_service`.
    async fn init_service(options: FirebaseOptions, backend: Option<Backend>) -> Arc<AiService> {
        let app = initialize_app(options, Some(unique_settings()))
            .await
            .unwrap();
        match backend {
            Some(backend) => crate::ai::get_ai(
                Some(app),
                Some(AiOptions {
                    backend: Some(backend),
                    use_limited_use_app_check_tokens: None,
                }),
            )
            .await
            .unwrap(),
            None => crate::ai::get_ai_service(Some(app)).await.unwrap(),
        }
    }

    #[test]
    fn google_ai_normalization_prefixes_only_bare_names() {
        assert_eq!(
            normalize_google_ai_model("gemini-pro".into()),
            "models/gemini-pro"
        );
        assert_eq!(
            normalize_google_ai_model("models/gemini-pro".into()),
            "models/gemini-pro"
        );
    }

    #[test]
    fn vertex_ai_normalization_handles_every_name_shape() {
        // Bare name: full publisher path is added.
        assert_eq!(
            normalize_vertex_ai_model("gemini-pro".into()),
            "publishers/google/models/gemini-pro"
        );
        // `models/` prefix: only the publisher segment is added.
        assert_eq!(
            normalize_vertex_ai_model("models/gemini-pro".into()),
            "publishers/google/models/gemini-pro"
        );
        // Already fully qualified: unchanged.
        assert_eq!(
            normalize_vertex_ai_model("publishers/google/models/gemini-pro".into()),
            "publishers/google/models/gemini-pro"
        );
        // Custom resource path (e.g. tuned model endpoint): unchanged.
        assert_eq!(
            normalize_vertex_ai_model("projects/p/locations/l/endpoints/e".into()),
            "projects/p/locations/l/endpoints/e"
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn normalizes_google_model_name() {
        let service = init_service(test_options(), None).await;
        let model = GenerativeModel::new(service.clone(), "gemini-pro", None).unwrap();
        assert_eq!(model.model(), "models/gemini-pro");

        let already_prefixed = GenerativeModel::new(service, "models/gemini-pro", None).unwrap();
        assert_eq!(already_prefixed.model(), "models/gemini-pro");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn normalizes_vertex_model_name_and_prepares_request() {
        let service = init_service(test_options(), Some(Backend::vertex_ai("us-central1"))).await;
        let model = GenerativeModel::new(
            service,
            "gemini-pro",
            Some(RequestOptions {
                timeout: Some(Duration::from_secs(5)),
                base_url: Some("https://example.com".into()),
            }),
        )
        .unwrap();

        assert_eq!(model.model(), "publishers/google/models/gemini-pro");

        let prepared = model
            .prepare_generate_content_request(json!({"contents": []}), None)
            .await
            .unwrap();
        assert_eq!(
            prepared.url.as_str(),
            "https://example.com/v1beta/projects/project/locations/us-central1/publishers/google/models/gemini-pro:generateContent"
        );
        assert_eq!(prepared.timeout, Duration::from_secs(5));
    }
}
182}