Skip to main content

rig_gemini_grpc/
client.rs

1use rig::prelude::*;
2use std::fmt::Debug;
3use tonic::metadata::MetadataValue;
4use tonic::service::Interceptor;
5use tonic::transport::{Channel, Endpoint};
6use tonic::{Request, Status};
7
8use super::GenerativeServiceClient;
9use crate::completion::CompletionModel;
10use crate::embedding::EmbeddingModel;
11
// ================================================================
// Google Gemini gRPC Client
// ================================================================

/// HTTPS origin of the Gemini `generativelanguage` service; the client connects
/// to it over TLS.
const GEMINI_GRPC_ENDPOINT: &str = "https://generativelanguage.googleapis.com";

/// Identifier attached to outgoing requests as the `x-goog-api-client` metadata entry.
const RIG_GRPC_CLIENT_IDENTIFIER: &str = "rig-grpc/0.1.0";
19
20#[derive(Clone)]
21pub struct Client {
22    api_key: String,
23    channel: Channel,
24}
25
26impl Debug for Client {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("Client")
29            .field("api_key", &"******")
30            .field("channel", &"Channel")
31            .finish()
32    }
33}
34
35// Interceptor to add API key and client identification to metadata
36#[derive(Clone)]
37pub struct ApiKeyInterceptor {
38    api_key: MetadataValue<tonic::metadata::Ascii>,
39    client_id: MetadataValue<tonic::metadata::Ascii>,
40}
41
42impl Interceptor for ApiKeyInterceptor {
43    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
44        request
45            .metadata_mut()
46            .insert("x-goog-api-key", self.api_key.clone());
47        request
48            .metadata_mut()
49            .insert("x-goog-api-client", self.client_id.clone());
50        Ok(request)
51    }
52}
53
54impl Client {
55    /// Create a gRPC client with the given API key
56    pub async fn new(
57        api_key: impl Into<String>,
58    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
59        let api_key = api_key.into();
60        let endpoint = Endpoint::from_static(GEMINI_GRPC_ENDPOINT).tls_config(
61            tonic::transport::ClientTlsConfig::new()
62                .with_webpki_roots()
63                .domain_name("generativelanguage.googleapis.com"),
64        )?;
65
66        let channel = endpoint.connect().await?;
67
68        Ok(Self { api_key, channel })
69    }
70
71    /// Get a gRPC client with API key interceptor
72    pub(crate) fn grpc_client(
73        &self,
74    ) -> Result<
75        GenerativeServiceClient<
76            tonic::service::interceptor::InterceptedService<Channel, ApiKeyInterceptor>,
77        >,
78        Box<dyn std::error::Error + Send + Sync>,
79    > {
80        let api_key = MetadataValue::try_from(&self.api_key)?;
81        let client_id = MetadataValue::try_from(RIG_GRPC_CLIENT_IDENTIFIER)?;
82        let interceptor = ApiKeyInterceptor { api_key, client_id };
83
84        Ok(GenerativeServiceClient::with_interceptor(
85            self.channel.clone(),
86            interceptor,
87        ))
88    }
89}
90
91impl ProviderClient for Client {
92    type Input = String;
93    type Error = Box<dyn std::error::Error + Send + Sync>;
94
95    /// Create a new Google Gemini gRPC client from the `GEMINI_API_KEY` environment variable.
96    fn from_env() -> Result<Self, Self::Error> {
97        let api_key = std::env::var("GEMINI_API_KEY")?;
98        tokio::task::block_in_place(|| {
99            tokio::runtime::Handle::current().block_on(Self::new(api_key))
100        })
101    }
102
103    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
104        tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(Self::new(input)))
105    }
106}
107
108impl CompletionClient for Client {
109    type CompletionModel = CompletionModel;
110
111    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
112        CompletionModel::new(self.clone(), model)
113    }
114}
115
116impl EmbeddingsClient for Client {
117    type EmbeddingModel = EmbeddingModel;
118
119    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
120        EmbeddingModel::new(self.clone(), model, None)
121    }
122
123    fn embedding_model_with_ndims(
124        &self,
125        model: impl Into<String>,
126        ndims: usize,
127    ) -> Self::EmbeddingModel {
128        EmbeddingModel::new(self.clone(), model, Some(ndims))
129    }
130}