Skip to main content

chalk_client/
grpc_client.rs

1//! gRPC client for the Chalk feature store.
2
3use tonic::transport::{Channel, ClientTlsConfig};
4
5use crate::auth::TokenManager;
6use crate::config::{ChalkClientConfig, ChalkClientConfigBuilder, ensure_scheme};
7use crate::error::{ChalkClientError, Result};
8use crate::gen::chalk::common::v1::{
9    OnlineQueryBulkRequest as ProtoOnlineQueryBulkRequest,
10    OnlineQueryBulkResponse as ProtoOnlineQueryBulkResponse,
11    OnlineQueryRequest as ProtoOnlineQueryRequest,
12    OnlineQueryResponse as ProtoOnlineQueryResponse,
13    UploadFeaturesBulkRequest as ProtoUploadFeaturesBulkRequest,
14    UploadFeaturesBulkResponse as ProtoUploadFeaturesBulkResponse,
15};
16use crate::gen::chalk::engine::v1::query_service_client::QueryServiceClient;
17
/// Value sent as the `user-agent` metadata entry on every request issued by
/// [`ChalkGrpcClient`].
const USER_AGENT: &str = "chalk-rust-grpc/0.1.0";
19
/// A gRPC client for the Chalk feature store.
///
/// [`ChalkGrpcClient`] is an alternative to [`ChalkClient`](crate::ChalkClient)
/// that uses gRPC (HTTP/2 + Protocol Buffers) instead of REST/JSON for lower
/// latency and higher throughput.
///
/// Supports [`query_proto`](Self::query_proto), [`query_bulk_proto`](Self::query_bulk_proto),
/// and [`upload_features_proto`](Self::upload_features_proto). These are
/// low-level methods that accept raw protobuf types. Offline queries are
/// only available via the REST client.
///
/// Instances are created through the builder returned by
/// [`ChalkGrpcClient::new`].
///
/// # Example
///
/// ```rust,no_run
/// use chalk_client::ChalkGrpcClient;
/// use chalk_client::gen::chalk::common::v1::{OnlineQueryRequest, OutputExpr};
/// use std::collections::HashMap;
///
/// # async fn example() -> chalk_client::error::Result<()> {
/// let client = ChalkGrpcClient::new()
///     .client_id("your-client-id")
///     .client_secret("your-client-secret")
///     .environment("production")
///     .build()
///     .await?;
///
/// let request = OnlineQueryRequest {
///     inputs: HashMap::from([(
///         "user.id".to_string(),
///         prost_types::Value {
///             kind: Some(prost_types::value::Kind::NumberValue(42.0)),
///         },
///     )]),
///     outputs: vec![OutputExpr {
///         expr: Some(chalk_client::gen::chalk::common::v1::output_expr::Expr::FeatureFqn(
///             "user.name".to_string(),
///         )),
///     }],
///     ..Default::default()
/// };
///
/// let response = client.query_proto(request).await?;
/// # Ok(())
/// # }
/// ```
pub struct ChalkGrpcClient {
    /// Resolved client configuration, exposed through [`ChalkGrpcClient::config`].
    config: ChalkClientConfig,
    /// Supplies the token placed in the `authorization` metadata of each request.
    token_manager: TokenManager,
    /// Generated query-service stub; a clone of it is used for every call.
    grpc_client: QueryServiceClient<Channel>,
    /// Environment ID resolved at build time and sent as `x-chalk-env-id`.
    environment_id: String,
}
71
/// Builder for [`ChalkGrpcClient`].
///
/// Obtained from [`ChalkGrpcClient::new`]. Every setter forwards its value to
/// the wrapped [`ChalkClientConfigBuilder`]; [`build`](Self::build) turns the
/// collected configuration into a connected client.
pub struct ChalkGrpcClientBuilder {
    /// Accumulates the configuration values supplied through the setters.
    config_builder: ChalkClientConfigBuilder,
}
76
77#[allow(clippy::new_ret_no_self)]
78impl ChalkGrpcClient {
79    /// Creates a new [`ChalkGrpcClientBuilder`] with authentication settings configured.
80    ///
81    /// Configuration is resolved from the first available source:
82    /// 1. Explicit values passed to the builder.
83    /// 2. Environment variables: `CHALK_CLIENT_ID`, `CHALK_CLIENT_SECRET`,
84    ///    `CHALK_API_SERVER`, `CHALK_ACTIVE_ENVIRONMENT`.
85    /// 3. `~/.chalk.yml` file, created by running `chalk login`.
86    pub fn new() -> ChalkGrpcClientBuilder {
87        ChalkGrpcClientBuilder {
88            config_builder: ChalkClientConfigBuilder::new(),
89        }
90    }
91}
92
93impl ChalkGrpcClientBuilder {
94    pub fn client_id(mut self, id: impl Into<String>) -> Self {
95        self.config_builder = self.config_builder.client_id(id);
96        self
97    }
98
99    pub fn client_secret(mut self, secret: impl Into<String>) -> Self {
100        self.config_builder = self.config_builder.client_secret(secret);
101        self
102    }
103
104    pub fn api_server(mut self, url: impl Into<String>) -> Self {
105        self.config_builder = self.config_builder.api_server(url);
106        self
107    }
108
109    pub fn environment(mut self, env: impl Into<String>) -> Self {
110        self.config_builder = self.config_builder.environment(env);
111        self
112    }
113
114    /// If specified, Chalk will route all requests from this client
115    /// to the relevant branch.
116    pub fn branch_id(mut self, id: impl Into<String>) -> Self {
117        self.config_builder = self.config_builder.branch_id(id);
118        self
119    }
120
121    /// Chalk can route queries to specific deployments using deployment tags.
122    pub fn deployment_tag(mut self, tag: impl Into<String>) -> Self {
123        self.config_builder = self.config_builder.deployment_tag(tag);
124        self
125    }
126
127    /// Chalk routes performance-sensitive requests like online query directly
128    /// to the query engine. Set this to override the automatically resolved
129    /// query server URL.
130    pub fn query_server(mut self, url: impl Into<String>) -> Self {
131        self.config_builder = self.config_builder.query_server(url);
132        self
133    }
134
135    /// Build the gRPC client, exchanging credentials for a token and
136    /// establishing an HTTP/2 connection to the query engine.
137    pub async fn build(self) -> Result<ChalkGrpcClient> {
138        let config = self.config_builder.build()?;
139        let token_manager = TokenManager::new(config.clone());
140        let token = token_manager.get_token().await?;
141
142        let environment_id = config
143            .environment
144            .clone()
145            .or(token.primary_environment.clone())
146            .ok_or_else(|| {
147                ChalkClientError::Config(
148                    "no environment specified and token has no primary_environment".into(),
149                )
150            })?;
151
152        // Priority: explicit query_server > grpc_engines > engines > api_server
153        let grpc_url = ensure_scheme(
154            config
155                .query_server
156                .clone()
157                .or_else(|| token.grpc_engines.get(&environment_id).cloned())
158                .or_else(|| token.engines.get(&environment_id).cloned())
159                .unwrap_or_else(|| config.api_server.clone()),
160        );
161
162        tracing::info!(
163            environment = %environment_id,
164            grpc_url = %grpc_url,
165            "connecting gRPC channel"
166        );
167
168        let mut endpoint = Channel::from_shared(grpc_url.clone()).map_err(|e| {
169            ChalkClientError::Config(format!("invalid gRPC URL '{}': {}", grpc_url, e))
170        })?;
171
172        if grpc_url.starts_with("https://") {
173            endpoint = endpoint
174                .tls_config(ClientTlsConfig::new().with_native_roots())
175                .map_err(|e| {
176                    ChalkClientError::Config(format!("TLS configuration error: {}", e))
177                })?;
178        }
179
180        let channel = endpoint.connect().await?;
181
182        let grpc_client = QueryServiceClient::new(channel);
183
184        tracing::info!("ChalkGrpcClient connected to {}", grpc_url);
185
186        Ok(ChalkGrpcClient {
187            config,
188            token_manager,
189            grpc_client,
190            environment_id,
191        })
192    }
193}
194
195impl ChalkGrpcClient {
196    /// Low-level: computes feature values for a single entity using the raw
197    /// protobuf request/response types.
198    ///
199    /// Prefer a higher-level wrapper (when available) over constructing proto
200    /// messages by hand. See <https://docs.chalk.ai/docs/query-basics>.
201    pub async fn query_proto(
202        &self,
203        request: ProtoOnlineQueryRequest,
204    ) -> Result<ProtoOnlineQueryResponse> {
205        let mut client = self.grpc_client.clone();
206        let mut req = tonic::Request::new(request);
207        self.inject_metadata(req.metadata_mut()).await?;
208        let response = client.online_query(req).await?;
209        Ok(response.into_inner())
210    }
211
212    /// Low-level: computes feature values for multiple entities at once using
213    /// the raw protobuf request/response types.
214    ///
215    /// Inputs and outputs use Arrow IPC (Feather) encoding inside the proto
216    /// messages.
217    pub async fn query_bulk_proto(
218        &self,
219        request: ProtoOnlineQueryBulkRequest,
220    ) -> Result<ProtoOnlineQueryBulkResponse> {
221        let mut client = self.grpc_client.clone();
222        let mut req = tonic::Request::new(request);
223        self.inject_metadata(req.metadata_mut()).await?;
224        let response = client.online_query_bulk(req).await?;
225        Ok(response.into_inner())
226    }
227
228    /// Low-level: uploads pre-computed feature values using the raw protobuf
229    /// request/response types.
230    pub async fn upload_features_proto(
231        &self,
232        request: ProtoUploadFeaturesBulkRequest,
233    ) -> Result<ProtoUploadFeaturesBulkResponse> {
234        let mut client = self.grpc_client.clone();
235        let mut req = tonic::Request::new(request);
236        self.inject_metadata(req.metadata_mut()).await?;
237        let response = client.upload_features_bulk(req).await?;
238        Ok(response.into_inner())
239    }
240
241    /// Returns the resolved environment ID.
242    pub fn environment_id(&self) -> &str {
243        &self.environment_id
244    }
245
246    /// Returns the current client configuration.
247    pub fn config(&self) -> &ChalkClientConfig {
248        &self.config
249    }
250
251    async fn inject_metadata(&self, metadata: &mut tonic::metadata::MetadataMap) -> Result<()> {
252        let token = self.token_manager.get_token().await?;
253
254        metadata.insert(
255            "authorization",
256            format!("Bearer {}", token.access_token)
257                .parse()
258                .map_err(|e| {
259                    ChalkClientError::Auth(format!("invalid token for metadata: {}", e))
260                })?,
261        );
262        metadata.insert(
263            "x-chalk-env-id",
264            self.environment_id
265                .parse()
266                .map_err(|e| ChalkClientError::Config(format!("invalid env ID: {}", e)))?,
267        );
268        metadata.insert(
269            "x-chalk-client-id",
270            self.config
271                .client_id
272                .parse()
273                .map_err(|e| ChalkClientError::Config(format!("invalid client ID: {}", e)))?,
274        );
275        metadata.insert(
276            "user-agent",
277            USER_AGENT
278                .parse()
279                .map_err(|e| ChalkClientError::Config(format!("invalid user-agent: {}", e)))?,
280        );
281        metadata.insert(
282            "x-chalk-deployment-type",
283            "engine-grpc".parse().unwrap(),
284        );
285        metadata.insert("x-chalk-server", "engine".parse().unwrap());
286
287        if let Some(ref branch) = self.config.branch_id {
288            metadata.insert(
289                "x-chalk-branch-id",
290                branch
291                    .parse()
292                    .map_err(|e| ChalkClientError::Config(format!("invalid branch ID: {}", e)))?,
293            );
294        }
295        if let Some(ref tag) = self.config.deployment_tag {
296            metadata.insert(
297                "x-chalk-deployment-tag",
298                tag.parse().map_err(|e| {
299                    ChalkClientError::Config(format!("invalid deployment tag: {}", e))
300                })?,
301            );
302        }
303
304        Ok(())
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives the real `inject_metadata` against a mocked token endpoint and
    /// checks every metadata entry it is expected to set.
    #[tokio::test]
    async fn test_metadata_injection() {
        let mut server = mockito::Server::new_async().await;

        server
            .mock("POST", "/v1/oauth/token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "access_token": "test-grpc-jwt",
                    "expires_in": 3600,
                    "primary_environment": "env-1",
                    "engines": {},
                    "grpc_engines": {}
                })
                .to_string(),
            )
            .create_async()
            .await;

        let config = ChalkClientConfigBuilder::new()
            .client_id("grpc-test-id")
            .client_secret("grpc-test-secret")
            .api_server(&server.url())
            .environment("env-1")
            .branch_id("branch-42")
            .deployment_tag("canary")
            .build()
            .unwrap();

        let token_manager = TokenManager::new(config.clone());
        let token = token_manager.get_token().await.unwrap();
        assert_eq!(token.access_token, "test-grpc-jwt");

        // A lazy channel is not dialed until the first RPC, and this test
        // issues none, so no gRPC server is needed.
        let channel = Channel::from_static("http://127.0.0.1:50051").connect_lazy();
        let client = ChalkGrpcClient {
            config,
            token_manager,
            grpc_client: QueryServiceClient::new(channel),
            environment_id: "env-1".to_string(),
        };

        let mut metadata = tonic::metadata::MetadataMap::new();
        client.inject_metadata(&mut metadata).await.unwrap();

        let expected = [
            ("authorization", "Bearer test-grpc-jwt"),
            ("x-chalk-env-id", "env-1"),
            ("x-chalk-client-id", "grpc-test-id"),
            ("user-agent", USER_AGENT),
            ("x-chalk-deployment-type", "engine-grpc"),
            ("x-chalk-server", "engine"),
            ("x-chalk-branch-id", "branch-42"),
            ("x-chalk-deployment-tag", "canary"),
        ];
        for (name, value) in expected {
            let actual = metadata
                .get(name)
                .unwrap_or_else(|| panic!("missing metadata entry `{}`", name));
            assert_eq!(actual.to_str().unwrap(), value, "metadata entry `{}`", name);
        }
    }
}