Skip to main content

chalk_client/
http_client.rs

1//! HTTP/REST client for the Chalk feature store.
2//!
3//! [`ChalkClient`] is the main entry point for most users. It talks to the
4//! Chalk API over HTTP/JSON (for `query`, `offline_query`) and HTTP with
5//! Arrow IPC binary bodies (for `query_bulk`, `upload_features`).
6//!
7//! ## Builder pattern
8//!
9//! You construct a `ChalkClient` using the builder, which mirrors the config
10//! builder but also performs the initial token exchange:
11//!
12//! ```rust,no_run
13//! use chalk_client::ChalkClient;
14//!
15//! # async fn example() -> chalk_client::error::Result<()> {
16//! let client = ChalkClient::new()
17//!     .client_id("my-client-id")
18//!     .client_secret("my-secret")
19//!     .environment("production")
20//!     .build()
21//!     .await?;
22//!
23//! // Now use it:
24//! // let response = client.query(inputs, outputs, options).await?;
25//! # Ok(())
26//! # }
27//! ```
28
29use std::collections::HashMap;
30use std::time::Duration;
31
32use arrow::ipc::writer::FileWriter;
33use arrow::record_batch::RecordBatch;
34use serde::Serialize;
35use serde_json::Value;
36
37use crate::auth::TokenManager;
38use crate::config::{ChalkClientConfig, ChalkClientConfigBuilder, ensure_scheme};
39use crate::error::{ChalkClientError, Result};
40use crate::offline::OfflineQueryParams;
41use crate::types::{
42    FeatureEncodingOptions, GetOfflineQueryJobResponse, GetOfflineQueryStatusResponse,
43    OfflineQueryRequest, OfflineQueryResponse, OnlineQueryContext, OnlineQueryRequest,
44    OnlineQueryResponse, QueryOptions, UploadFeaturesResult,
45};
46
/// Value of the `User-Agent` header attached to every outgoing request.
const USER_AGENT: &str = "chalk-rust/0.1.0";

/// Leading magic bytes that open a multi-query feather request body.
const MULTI_QUERY_MAGIC_STR: &[u8] = "chal1".as_bytes();

/// Leading magic bytes that open a ByteBaseModel-encoded response.
const BYTEMODEL_MAGIC_STR: &[u8] = "CHALK_BYTE_TRANSMISSION".as_bytes();
55
56// =========================================================================
57// ChalkClient
58// =========================================================================
59
60/// An HTTP/REST client for the Chalk feature store.
61pub struct ChalkClient {
62    /// The resolved configuration.
63    config: ChalkClientConfig,
64
65    /// Manages JWT tokens (exchange + caching).
66    token_manager: TokenManager,
67
68    /// The underlying HTTP client (connection pooling, TLS, etc.).
69    http_client: reqwest::Client,
70
71    /// The resolved query server URL.
72    query_server: String,
73
74    /// The resolved environment ID.
75    environment_id: String,
76}
77
78// =========================================================================
79// Builder
80// =========================================================================
81
82/// Builder for [`ChalkClient`].
83pub struct ChalkClientBuilder {
84    config_builder: ChalkClientConfigBuilder,
85}
86
87impl ChalkClient {
88    /// Start building a new `ChalkClient`.
89    #[allow(clippy::new_ret_no_self)]
90    pub fn new() -> ChalkClientBuilder {
91        ChalkClientBuilder {
92            config_builder: ChalkClientConfigBuilder::new(),
93        }
94    }
95}
96
97impl ChalkClientBuilder {
98    /// Set the OAuth2 client ID.
99    pub fn client_id(mut self, id: impl Into<String>) -> Self {
100        self.config_builder = self.config_builder.client_id(id);
101        self
102    }
103
104    /// Set the OAuth2 client secret.
105    pub fn client_secret(mut self, secret: impl Into<String>) -> Self {
106        self.config_builder = self.config_builder.client_secret(secret);
107        self
108    }
109
110    /// Set the API server URL.
111    pub fn api_server(mut self, url: impl Into<String>) -> Self {
112        self.config_builder = self.config_builder.api_server(url);
113        self
114    }
115
116    /// Set the target environment.
117    pub fn environment(mut self, env: impl Into<String>) -> Self {
118        self.config_builder = self.config_builder.environment(env);
119        self
120    }
121
122    /// Set the branch ID.
123    pub fn branch_id(mut self, id: impl Into<String>) -> Self {
124        self.config_builder = self.config_builder.branch_id(id);
125        self
126    }
127
128    /// Set the deployment tag.
129    pub fn deployment_tag(mut self, tag: impl Into<String>) -> Self {
130        self.config_builder = self.config_builder.deployment_tag(tag);
131        self
132    }
133
134    /// Set the query server URL directly.
135    pub fn query_server(mut self, url: impl Into<String>) -> Self {
136        self.config_builder = self.config_builder.query_server(url);
137        self
138    }
139
140    /// Build the client.
141    ///
142    /// This is `async` because it performs the initial token exchange to
143    /// discover the query engine URL and validate credentials.
144    pub async fn build(self) -> Result<ChalkClient> {
145        let config = self.config_builder.build()?;
146
147        let token_manager = TokenManager::new(config.clone());
148        let token = token_manager.get_token().await?;
149
150        let environment_id = config
151            .environment
152            .clone()
153            .or(token.primary_environment.clone())
154            .ok_or_else(|| {
155                ChalkClientError::Config(
156                    "no environment specified and token has no primary_environment".into(),
157                )
158            })?;
159
160        let query_server = ensure_scheme(
161            config
162                .query_server
163                .clone()
164                .or_else(|| token.engines.get(&environment_id).cloned())
165                .unwrap_or_else(|| config.api_server.clone()),
166        );
167
168        tracing::info!(
169            environment = %environment_id,
170            query_server = %query_server,
171            "ChalkClient initialized"
172        );
173
174        Ok(ChalkClient {
175            config,
176            token_manager,
177            http_client: reqwest::Client::new(),
178            query_server,
179            environment_id,
180        })
181    }
182}
183
184// =========================================================================
185// Query methods
186// =========================================================================
187
188impl ChalkClient {
189    /// Query features online (single entity, JSON request/response).
190    ///
191    /// # Arguments
192    ///
193    /// * `inputs` — Known feature values, e.g. `{"user.id": 42}`.
194    /// * `outputs` — Which features to compute, e.g. `["user.age", "user.name"]`.
195    /// * `options` — Optional settings (staleness, tags, etc.).
196    pub async fn query(
197        &self,
198        inputs: HashMap<String, Value>,
199        outputs: Vec<String>,
200        options: QueryOptions,
201    ) -> Result<OnlineQueryResponse> {
202        let url = format!("{}/v1/query/online", self.engine_url());
203
204        let body = OnlineQueryRequest {
205            inputs,
206            outputs,
207            context: options.context,
208            staleness: options.staleness,
209            include_meta: options.include_meta,
210            query_name: options.query_name,
211            correlation_id: options.correlation_id,
212            query_context: options.query_context,
213            meta: options.meta,
214            query_name_version: options.query_name_version,
215            now: options.now,
216            explain: options.explain,
217            store_plan_stages: options.store_plan_stages,
218            encoding_options: options.encoding_options,
219            branch_id: options.branch_id.or(self.config.branch_id.clone()),
220        };
221
222        let resp = self
223            .send_json_request(reqwest::Method::POST, &url, &body)
224            .await?;
225
226        let status = resp.status();
227        let body_text = resp.text().await?;
228
229        if !status.is_success() {
230            return Err(ChalkClientError::Api {
231                status: status.as_u16(),
232                message: body_text,
233            });
234        }
235
236        let response: OnlineQueryResponse = serde_json::from_str(&body_text)?;
237
238        if !response.errors.is_empty() {
239            tracing::warn!(
240                error_count = response.errors.len(),
241                "query returned server errors"
242            );
243        }
244
245        Ok(response)
246    }
247
248    /// Query features in bulk using the Chalk feather protocol.
249    ///
250    /// You provide inputs as an Arrow `RecordBatch` (one column per input
251    /// feature, one row per entity) and get back a `BulkQueryResult` containing
252    /// the output features as raw Feather bytes.
253    pub async fn query_bulk(
254        &self,
255        inputs: &RecordBatch,
256        outputs: Vec<String>,
257        options: QueryOptions,
258    ) -> Result<BulkQueryResult> {
259        let url = format!("{}/v1/query/feather", self.engine_url());
260
261        let header = FeatherRequestHeader {
262            outputs: outputs.clone(),
263            expression_outputs: vec![],
264            now: None,
265            staleness: options.staleness,
266            context: options.context,
267            include_meta: options.include_meta.unwrap_or(true),
268            explain: options.explain.unwrap_or(false),
269            correlation_id: options.correlation_id,
270            query_name: options.query_name,
271            query_name_version: options.query_name_version,
272            deployment_id: None,
273            branch_id: options.branch_id.or(self.config.branch_id.clone()),
274            meta: options.meta,
275            store_plan_stages: options.store_plan_stages.or(Some(false)),
276            query_context: options.query_context,
277            encoding_options: options
278                .encoding_options
279                .unwrap_or(FeatureEncodingOptions {
280                    encode_structs_as_objects: None,
281                }),
282            planner_options: options.planner_options,
283            value_metrics_tag_by_features: vec![],
284            overlay_graph: None,
285        };
286
287        let feather_bytes = serialize_record_batch_to_feather(inputs)?;
288
289        let request_body = build_feather_request_body(&header, &feather_bytes)?;
290
291        let token = self.token_manager.get_token().await?;
292
293        let deployment_type = if self.config.branch_id.is_some() {
294            "branch"
295        } else {
296            "engine"
297        };
298
299        let mut request = self
300            .http_client
301            .post(&url)
302            .header("Authorization", format!("Bearer {}", token.access_token))
303            .header("User-Agent", USER_AGENT)
304            .header("Content-Type", "application/octet-stream")
305            .header("Accept", "application/octet-stream")
306            .header("X-Chalk-Client-Id", &self.config.client_id)
307            .header("X-Chalk-Env-Id", &self.environment_id)
308            .header("X-Chalk-Deployment-Type", deployment_type)
309            .header("X-Chalk-Features-Versioned", "true");
310
311        if let Some(ref branch) = self.config.branch_id {
312            request = request.header("X-Chalk-Branch-Id", branch.as_str());
313        }
314        if let Some(ref tag) = self.config.deployment_tag {
315            request = request.header("X-Chalk-Deployment-Tag", tag);
316        }
317
318        let resp = request.body(request_body).send().await?;
319
320        let status = resp.status();
321        if !status.is_success() {
322            let body = resp.text().await.unwrap_or_default();
323            return Err(ChalkClientError::Api {
324                status: status.as_u16(),
325                message: body,
326            });
327        }
328
329        let response_bytes = resp.bytes().await?;
330        parse_bulk_query_response(&response_bytes)
331    }
332
333    /// Run an offline query using the builder pattern.
334    ///
335    /// # Example
336    ///
337    /// ```rust,no_run
338    /// # use chalk_client::{ChalkClient, OfflineQueryParams};
339    /// # async fn example(client: &ChalkClient) -> chalk_client::error::Result<()> {
340    /// let response = client.offline_query(
341    ///     OfflineQueryParams::new()
342    ///         .with_input("user.id", vec![serde_json::json!(1), serde_json::json!(2)])
343    ///         .with_output("user.email")
344    /// ).await?;
345    /// # Ok(())
346    /// # }
347    /// ```
348    pub async fn offline_query(
349        &self,
350        params: OfflineQueryParams,
351    ) -> Result<OfflineQueryResponse> {
352        let request = params.build()?;
353        self.offline_query_raw(request).await
354    }
355
356    /// Run an offline query with a raw [`OfflineQueryRequest`].
357    pub async fn offline_query_raw(
358        &self,
359        request: OfflineQueryRequest,
360    ) -> Result<OfflineQueryResponse> {
361        let url = format!("{}/v4/offline_query", self.config.api_server);
362
363        let resp = self
364            .send_json_request(reqwest::Method::POST, &url, &request)
365            .await?;
366
367        let status = resp.status();
368        let body_text = resp.text().await?;
369
370        if !status.is_success() {
371            return Err(ChalkClientError::Api {
372                status: status.as_u16(),
373                message: body_text,
374            });
375        }
376
377        let response: OfflineQueryResponse = serde_json::from_str(&body_text)?;
378        Ok(response)
379    }
380
381    /// Get the status of an offline query job.
382    pub async fn get_offline_query_status(
383        &self,
384        job_id: &str,
385    ) -> Result<GetOfflineQueryStatusResponse> {
386        let url = format!(
387            "{}/v4/offline_query/{}/status",
388            self.config.api_server, job_id
389        );
390
391        let resp = self
392            .send_get_request(&url)
393            .await?;
394
395        let status = resp.status();
396        let body_text = resp.text().await?;
397
398        if !status.is_success() {
399            return Err(ChalkClientError::Api {
400                status: status.as_u16(),
401                message: body_text,
402            });
403        }
404
405        let response: GetOfflineQueryStatusResponse = serde_json::from_str(&body_text)?;
406        Ok(response)
407    }
408
409    /// Wait for an offline query job to complete.
410    ///
411    /// Polls [`get_offline_query_status`](Self::get_offline_query_status) every
412    /// second until the job reaches `"COMPLETED"` or `"FAILED"` status.
413    pub async fn wait_for_offline_query(
414        &self,
415        response: &OfflineQueryResponse,
416        timeout: Option<Duration>,
417    ) -> Result<()> {
418        let revision = response
419            .revisions
420            .last()
421            .and_then(|r| r.revision_id.as_deref())
422            .ok_or_else(|| {
423                ChalkClientError::Config("offline query response has no revision ID".into())
424            })?;
425
426        let poll_fut = async {
427            loop {
428                let status_resp = self.get_offline_query_status(revision).await?;
429                let report = match status_resp.report {
430                    Some(r) => r,
431                    None => {
432                        tokio::time::sleep(Duration::from_secs(1)).await;
433                        continue;
434                    }
435                };
436                let status = report.status.as_deref().unwrap_or("UNKNOWN");
437
438                match status {
439                    "COMPLETED" => return Ok(()),
440                    "FAILED" => {
441                        let errors = report.all_errors;
442                        if errors.is_empty() {
443                            if let Some(err) = report.error {
444                                return Err(ChalkClientError::ServerErrors(vec![err]));
445                            }
446                            return Err(ChalkClientError::Api {
447                                status: 0,
448                                message: "offline query failed with no error details".into(),
449                            });
450                        }
451                        return Err(ChalkClientError::ServerErrors(errors));
452                    }
453                    _ => {
454                        tokio::time::sleep(Duration::from_secs(1)).await;
455                    }
456                }
457            }
458        };
459
460        if let Some(timeout_dur) = timeout {
461            tokio::time::timeout(timeout_dur, poll_fut)
462                .await
463                .map_err(|_| {
464                    ChalkClientError::Api {
465                        status: 0,
466                        message: format!(
467                            "timed out waiting for offline query after {:?}",
468                            timeout_dur
469                        ),
470                    }
471                })?
472        } else {
473            poll_fut.await
474        }
475    }
476
477    /// Get download URLs for an offline query's result Parquet files.
478    pub async fn get_offline_query_download_urls(
479        &self,
480        response: &OfflineQueryResponse,
481        timeout: Option<Duration>,
482    ) -> Result<Vec<String>> {
483        let revision_id = response
484            .revisions
485            .last()
486            .and_then(|r| r.revision_id.as_deref())
487            .ok_or_else(|| {
488                ChalkClientError::Config("offline query response has no revision ID".into())
489            })?;
490
491        let poll_fut = async {
492            loop {
493                let url = format!(
494                    "{}/v2/offline_query/{}",
495                    self.config.api_server, revision_id
496                );
497
498                let resp = self.send_get_request(&url).await?;
499                let status = resp.status();
500                let body_text = resp.text().await?;
501
502                if !status.is_success() {
503                    return Err(ChalkClientError::Api {
504                        status: status.as_u16(),
505                        message: body_text,
506                    });
507                }
508
509                let job_resp: GetOfflineQueryJobResponse = serde_json::from_str(&body_text)?;
510
511                if job_resp.is_finished {
512                    if !job_resp.errors.is_empty() {
513                        return Err(ChalkClientError::ServerErrors(job_resp.errors));
514                    }
515                    return Ok(job_resp.urls);
516                }
517
518                tokio::time::sleep(Duration::from_millis(500)).await;
519            }
520        };
521
522        if let Some(timeout_dur) = timeout {
523            tokio::time::timeout(timeout_dur, poll_fut)
524                .await
525                .map_err(|_| {
526                    ChalkClientError::Api {
527                        status: 0,
528                        message: format!(
529                            "timed out waiting for download URLs after {:?}",
530                            timeout_dur
531                        ),
532                    }
533                })?
534        } else {
535            poll_fut.await
536        }
537    }
538
539    /// Upload feature values to the Chalk feature store.
540    ///
541    /// # Arguments
542    ///
543    /// * `features` — An Arrow RecordBatch where each column is a feature
544    ///   (column names are feature FQNs like `"user.age"`).
545    pub async fn upload_features(
546        &self,
547        features: &RecordBatch,
548    ) -> Result<UploadFeaturesResult> {
549        let url = format!("{}/v1/upload_features/multi", self.engine_url());
550
551        let feature_names: Vec<String> = features
552            .schema()
553            .fields()
554            .iter()
555            .map(|f| f.name().clone())
556            .collect();
557
558        let feather_bytes = serialize_record_batch_to_feather(features)?;
559
560        let json_attrs = serde_json::json!({
561            "features": feature_names,
562            "table_compression": "uncompressed",
563        });
564        let body = build_byte_base_model(&json_attrs, &[("table_bytes", &feather_bytes)])?;
565
566        let token = self.token_manager.get_token().await?;
567
568        let deployment_type = if self.config.branch_id.is_some() {
569            "branch"
570        } else {
571            "engine"
572        };
573
574        let mut request = self
575            .http_client
576            .post(&url)
577            .header("Authorization", format!("Bearer {}", token.access_token))
578            .header("User-Agent", USER_AGENT)
579            .header("Content-Type", "application/octet-stream")
580            .header("Accept", "application/json")
581            .header("X-Chalk-Client-Id", &self.config.client_id)
582            .header("X-Chalk-Env-Id", &self.environment_id)
583            .header("X-Chalk-Deployment-Type", deployment_type)
584            .header("X-Chalk-Features-Versioned", "true");
585
586        if let Some(ref branch) = self.config.branch_id {
587            request = request.header("X-Chalk-Branch-Id", branch.as_str());
588        }
589        if let Some(ref tag) = self.config.deployment_tag {
590            request = request.header("X-Chalk-Deployment-Tag", tag);
591        }
592
593        let resp = request.body(body).send().await?;
594
595        let status = resp.status();
596        let body_text = resp.text().await?;
597
598        if !status.is_success() {
599            return Err(ChalkClientError::Api {
600                status: status.as_u16(),
601                message: body_text,
602            });
603        }
604
605        let result: UploadFeaturesResult = serde_json::from_str(&body_text)?;
606
607        if !result.errors.is_empty() {
608            tracing::warn!(
609                error_count = result.errors.len(),
610                "upload_features returned server errors"
611            );
612        }
613
614        Ok(result)
615    }
616
617    /// Upload feature values from a map of feature names to value arrays.
618    ///
619    /// # Example
620    ///
621    /// ```rust,no_run
622    /// # use chalk_client::ChalkClient;
623    /// # use std::collections::HashMap;
624    /// # async fn example(client: &ChalkClient) -> chalk_client::error::Result<()> {
625    /// let inputs = HashMap::from([
626    ///     ("user.id".to_string(), vec![serde_json::json!(1), serde_json::json!(2)]),
627    ///     ("user.name".to_string(), vec![serde_json::json!("Alice"), serde_json::json!("Bob")]),
628    /// ]);
629    /// let result = client.upload_features_map(inputs).await?;
630    /// # Ok(())
631    /// # }
632    /// ```
633    pub async fn upload_features_map(
634        &self,
635        inputs: HashMap<String, Vec<Value>>,
636    ) -> Result<UploadFeaturesResult> {
637        use arrow::array::StringArray;
638        use arrow::datatypes::{DataType, Field, Schema};
639        use std::sync::Arc;
640
641        if inputs.is_empty() {
642            return Err(ChalkClientError::Config(
643                "upload_features_map requires at least one feature".into(),
644            ));
645        }
646
647        let mut feature_names: Vec<String> = inputs.keys().cloned().collect();
648        feature_names.sort();
649
650        let num_rows = inputs[&feature_names[0]].len();
651
652        let fields: Vec<Field> = feature_names
653            .iter()
654            .map(|name| Field::new(name, DataType::Utf8, true))
655            .collect();
656        let schema = Arc::new(Schema::new(fields));
657
658        let columns: Vec<Arc<dyn arrow::array::Array>> = feature_names
659            .iter()
660            .map(|name| {
661                let values = &inputs[name];
662                let strings: Vec<Option<String>> = values
663                    .iter()
664                    .map(|v| match v {
665                        Value::Null => None,
666                        Value::String(s) => Some(s.clone()),
667                        other => Some(other.to_string()),
668                    })
669                    .collect();
670                Arc::new(StringArray::from(strings)) as Arc<dyn arrow::array::Array>
671            })
672            .collect();
673
674        let batch = RecordBatch::try_new(schema, columns).map_err(|e| {
675            ChalkClientError::Arrow(e)
676        })?;
677
678        if batch.num_rows() != num_rows {
679            return Err(ChalkClientError::Config(
680                "all input arrays must be the same length".into(),
681            ));
682        }
683
684        self.upload_features(&batch).await
685    }
686
687    /// Returns the resolved environment ID.
688    pub fn environment_id(&self) -> &str {
689        &self.environment_id
690    }
691
692    /// Returns the resolved query server URL.
693    pub fn query_server(&self) -> &str {
694        &self.query_server
695    }
696
697    // =====================================================================
698    // Internal helpers
699    // =====================================================================
700
701    fn engine_url(&self) -> &str {
702        if self.config.branch_id.is_some() {
703            &self.config.api_server
704        } else {
705            &self.query_server
706        }
707    }
708
709    async fn send_json_request<T: serde::Serialize>(
710        &self,
711        method: reqwest::Method,
712        url: &str,
713        body: &T,
714    ) -> Result<reqwest::Response> {
715        let token = self.token_manager.get_token().await?;
716
717        let deployment_type = if self.config.branch_id.is_some() {
718            "branch"
719        } else {
720            "engine"
721        };
722
723        let mut request = self
724            .http_client
725            .request(method, url)
726            .header("Authorization", format!("Bearer {}", token.access_token))
727            .header("Content-Type", "application/json")
728            .header("Accept", "application/json")
729            .header("User-Agent", USER_AGENT)
730            .header("X-Chalk-Client-Id", &self.config.client_id)
731            .header("X-Chalk-Env-Id", &self.environment_id)
732            .header("X-Chalk-Deployment-Type", deployment_type)
733            .header("X-Chalk-Features-Versioned", "true");
734
735        if let Some(ref branch) = self.config.branch_id {
736            request = request.header("X-Chalk-Branch-Id", branch.as_str());
737        }
738        if let Some(ref tag) = self.config.deployment_tag {
739            request = request.header("X-Chalk-Deployment-Tag", tag);
740        }
741
742        let resp = request.json(body).send().await?;
743        Ok(resp)
744    }
745
746    async fn send_get_request(&self, url: &str) -> Result<reqwest::Response> {
747        let token = self.token_manager.get_token().await?;
748
749        let deployment_type = if self.config.branch_id.is_some() {
750            "branch"
751        } else {
752            "engine"
753        };
754
755        let mut request = self
756            .http_client
757            .get(url)
758            .header("Authorization", format!("Bearer {}", token.access_token))
759            .header("Accept", "application/json")
760            .header("User-Agent", USER_AGENT)
761            .header("X-Chalk-Client-Id", &self.config.client_id)
762            .header("X-Chalk-Env-Id", &self.environment_id)
763            .header("X-Chalk-Deployment-Type", deployment_type)
764            .header("X-Chalk-Features-Versioned", "true");
765
766        if let Some(ref branch) = self.config.branch_id {
767            request = request.header("X-Chalk-Branch-Id", branch.as_str());
768        }
769        if let Some(ref tag) = self.config.deployment_tag {
770            request = request.header("X-Chalk-Deployment-Tag", tag);
771        }
772
773        let resp = request.send().await?;
774        Ok(resp)
775    }
776}
777
778// =========================================================================
779// Feather request protocol types
780// =========================================================================
781
782#[derive(Debug, Serialize)]
783struct FeatherRequestHeader {
784    outputs: Vec<String>,
785    #[serde(default)]
786    expression_outputs: Vec<String>,
787    #[serde(skip_serializing_if = "Option::is_none")]
788    now: Option<Vec<String>>,
789    #[serde(skip_serializing_if = "Option::is_none")]
790    staleness: Option<HashMap<String, String>>,
791    #[serde(skip_serializing_if = "Option::is_none")]
792    context: Option<OnlineQueryContext>,
793    include_meta: bool,
794    explain: bool,
795    #[serde(skip_serializing_if = "Option::is_none")]
796    correlation_id: Option<String>,
797    #[serde(skip_serializing_if = "Option::is_none")]
798    query_name: Option<String>,
799    #[serde(skip_serializing_if = "Option::is_none")]
800    query_name_version: Option<String>,
801    #[serde(skip_serializing_if = "Option::is_none")]
802    deployment_id: Option<String>,
803    #[serde(skip_serializing_if = "Option::is_none")]
804    branch_id: Option<String>,
805    #[serde(skip_serializing_if = "Option::is_none")]
806    meta: Option<HashMap<String, String>>,
807    #[serde(skip_serializing_if = "Option::is_none")]
808    store_plan_stages: Option<bool>,
809    #[serde(skip_serializing_if = "Option::is_none")]
810    query_context: Option<HashMap<String, Value>>,
811    encoding_options: FeatureEncodingOptions,
812    #[serde(skip_serializing_if = "Option::is_none")]
813    planner_options: Option<HashMap<String, Value>>,
814    #[serde(default)]
815    value_metrics_tag_by_features: Vec<String>,
816    #[serde(skip_serializing_if = "Option::is_none")]
817    overlay_graph: Option<String>,
818}
819
820// =========================================================================
821// Bulk query response types
822// =========================================================================
823
/// The result of a bulk (feather) query.
///
/// Derives the common value traits so callers can clone, compare and assert
/// on results (e.g. in tests) without hand-written impls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BulkQueryResult {
    /// The output features as raw Feather (Arrow IPC file) bytes.
    pub scalar_data: Vec<u8>,

    /// Whether the server indicated it has data.
    pub has_data: bool,

    /// JSON-stringified query metadata.
    pub meta: Option<String>,

    /// JSON-stringified error objects from the server.
    pub errors: Vec<String>,
}
839
840// =========================================================================
841// Feather request serialization
842// =========================================================================
843
844fn build_feather_request_body(header: &FeatherRequestHeader, feather_bytes: &[u8]) -> Result<Vec<u8>> {
845    let header_json = serde_json::to_string(header)?;
846    let header_bytes = header_json.as_bytes();
847
848    let total_size = 5 + 8 + header_bytes.len() + 8 + feather_bytes.len();
849    let mut buf = Vec::with_capacity(total_size);
850
851    buf.extend_from_slice(MULTI_QUERY_MAGIC_STR);
852
853    buf.extend_from_slice(&(header_bytes.len() as u64).to_be_bytes());
854    buf.extend_from_slice(header_bytes);
855
856    buf.extend_from_slice(&(feather_bytes.len() as u64).to_be_bytes());
857    buf.extend_from_slice(feather_bytes);
858
859    Ok(buf)
860}
861
862// =========================================================================
863// ByteBaseModel response parsing
864// =========================================================================
865
866fn parse_bulk_query_response(data: &[u8]) -> Result<BulkQueryResult> {
867    let mut pos: usize = 0;
868
869    pos = consume_magic(data, pos)?;
870
871    let (new_pos, _attrs_json) = read_length_prefixed_json(data, pos)?;
872    pos = new_pos;
873
874    let (new_pos, _pydantic_json) = read_length_prefixed_json(data, pos)?;
875    pos = new_pos;
876
877    let (new_pos, byte_offset_map) = read_length_prefixed_json(data, pos)?;
878    pos = new_pos;
879    pos = skip_byte_data(data, pos, &byte_offset_map)?;
880
881    let (new_pos, serializable_offset_map) = read_length_prefixed_json(data, pos)?;
882    pos = new_pos;
883
884    let query_results_len = serializable_offset_map
885        .get("query_results_bytes")
886        .and_then(|v| v.as_u64())
887        .ok_or_else(|| ChalkClientError::Api {
888            status: 0,
889            message: format!(
890                "missing query_results_bytes in serializable_attrs (got: {})",
891                serializable_offset_map
892            ),
893        })? as usize;
894
895    if pos + query_results_len > data.len() {
896        return Err(ChalkClientError::Api {
897            status: 0,
898            message: "response truncated: query_results_bytes extends beyond data".into(),
899        });
900    }
901    let query_results_bytes = &data[pos..pos + query_results_len];
902
903    parse_query_result_feather(query_results_bytes)
904}
905
906fn parse_query_result_feather(data: &[u8]) -> Result<BulkQueryResult> {
907    let mut pos: usize = 0;
908
909    pos = consume_magic(data, pos)?;
910
911    let (new_pos, _) = read_length_prefixed_json(data, pos)?;
912    pos = new_pos;
913
914    let (new_pos, _) = read_length_prefixed_json(data, pos)?;
915    pos = new_pos;
916
917    let (new_pos, byte_offset_map) = read_length_prefixed_json(data, pos)?;
918    pos = new_pos;
919
920    let (_query_key, result_len) = byte_offset_map
921        .as_object()
922        .and_then(|m| m.iter().next())
923        .and_then(|(k, v)| v.as_u64().map(|len| (k.clone(), len as usize)))
924        .ok_or_else(|| ChalkClientError::Api {
925            status: 0,
926            message: "empty byte_attrs in query results ByteDict".into(),
927        })?;
928
929    if pos + result_len > data.len() {
930        return Err(ChalkClientError::Api {
931            status: 0,
932            message: "response truncated: result bytes extend beyond data".into(),
933        });
934    }
935    let result_bytes = &data[pos..pos + result_len];
936
937    parse_online_query_result_feather(result_bytes)
938}
939
940fn parse_online_query_result_feather(data: &[u8]) -> Result<BulkQueryResult> {
941    let mut pos: usize = 0;
942
943    pos = consume_magic(data, pos)?;
944
945    let (new_pos, json_attrs) = read_length_prefixed_json(data, pos)?;
946    pos = new_pos;
947
948    let has_data = json_attrs
949        .get("has_data")
950        .and_then(|v| v.as_bool())
951        .unwrap_or(false);
952
953    let meta = json_attrs
954        .get("meta")
955        .and_then(|v| v.as_str())
956        .map(|s| s.to_string());
957
958    let errors: Vec<String> = json_attrs
959        .get("errors")
960        .and_then(|v| v.as_array())
961        .map(|arr| {
962            arr.iter()
963                .filter_map(|v| v.as_str().map(|s| s.to_string()))
964                .collect()
965        })
966        .unwrap_or_default();
967
968    let (new_pos, _) = read_length_prefixed_json(data, pos)?;
969    pos = new_pos;
970
971    let (new_pos, byte_offset_map) = read_length_prefixed_json(data, pos)?;
972    pos = new_pos;
973
974    let scalar_data_len = byte_offset_map
975        .get("scalar_data")
976        .and_then(|v| v.as_u64())
977        .unwrap_or(0) as usize;
978
979    let scalar_data = if scalar_data_len > 0 && pos + scalar_data_len <= data.len() {
980        data[pos..pos + scalar_data_len].to_vec()
981    } else {
982        vec![]
983    };
984
985    Ok(BulkQueryResult {
986        scalar_data,
987        has_data,
988        meta,
989        errors,
990    })
991}
992
993fn consume_magic(data: &[u8], pos: usize) -> Result<usize> {
994    if pos + BYTEMODEL_MAGIC_STR.len() > data.len() {
995        return Err(ChalkClientError::Api {
996            status: 0,
997            message: format!(
998                "response too short for magic string at position {} ({} bytes available)",
999                pos,
1000                data.len() - pos
1001            ),
1002        });
1003    }
1004    if &data[pos..pos + BYTEMODEL_MAGIC_STR.len()] != BYTEMODEL_MAGIC_STR {
1005        return Err(ChalkClientError::Api {
1006            status: 0,
1007            message: format!(
1008                "invalid ByteBaseModel magic at position {} (got {:?})",
1009                pos,
1010                &data[pos..std::cmp::min(pos + BYTEMODEL_MAGIC_STR.len(), data.len())]
1011            ),
1012        });
1013    }
1014    Ok(pos + BYTEMODEL_MAGIC_STR.len())
1015}
1016
1017fn skip_byte_data(data: &[u8], pos: usize, offset_map: &Value) -> Result<usize> {
1018    let total_bytes: usize = offset_map
1019        .as_object()
1020        .map(|m| {
1021            m.values()
1022                .filter_map(|v| v.as_u64())
1023                .map(|v| v as usize)
1024                .sum()
1025        })
1026        .unwrap_or(0);
1027
1028    if pos + total_bytes > data.len() {
1029        return Err(ChalkClientError::Api {
1030            status: 0,
1031            message: format!(
1032                "response truncated: byte data of {} bytes at position {} extends beyond data (total {})",
1033                total_bytes, pos, data.len()
1034            ),
1035        });
1036    }
1037
1038    Ok(pos + total_bytes)
1039}
1040
1041fn read_length_prefixed_json(data: &[u8], pos: usize) -> Result<(usize, Value)> {
1042    if pos + 8 > data.len() {
1043        return Err(ChalkClientError::Api {
1044            status: 0,
1045            message: format!(
1046                "response truncated: expected 8-byte length at position {}, but only {} bytes remain",
1047                pos,
1048                data.len() - pos
1049            ),
1050        });
1051    }
1052
1053    let len = u64::from_be_bytes(data[pos..pos + 8].try_into().unwrap()) as usize;
1054    let json_start = pos + 8;
1055
1056    if json_start + len > data.len() {
1057        return Err(ChalkClientError::Api {
1058            status: 0,
1059            message: format!(
1060                "response truncated: JSON payload of {} bytes at position {} extends beyond data (total {})",
1061                len, json_start, data.len()
1062            ),
1063        });
1064    }
1065
1066    let json_str = std::str::from_utf8(&data[json_start..json_start + len]).map_err(|e| {
1067        ChalkClientError::Api {
1068            status: 0,
1069            message: format!("invalid UTF-8 in response JSON: {}", e),
1070        }
1071    })?;
1072
1073    let value: Value = serde_json::from_str(json_str)?;
1074    Ok((json_start + len, value))
1075}
1076
1077// =========================================================================
1078// ByteBaseModel serialization (request direction)
1079// =========================================================================
1080
1081fn build_byte_base_model(
1082    json_attrs: &Value,
1083    byte_attrs: &[(&str, &[u8])],
1084) -> Result<Vec<u8>> {
1085    let json_attrs_bytes = serde_json::to_vec(json_attrs)?;
1086    let empty_json = b"{}";
1087
1088    let byte_offset_map = {
1089        let mut map = serde_json::Map::new();
1090        for (key, data) in byte_attrs {
1091            map.insert((*key).to_string(), Value::Number((data.len() as u64).into()));
1092        }
1093        serde_json::to_vec(&Value::Object(map))?
1094    };
1095
1096    let total_byte_data: usize = byte_attrs.iter().map(|(_, d)| d.len()).sum();
1097
1098    let total_size = BYTEMODEL_MAGIC_STR.len()
1099        + 4 * 8
1100        + json_attrs_bytes.len()
1101        + empty_json.len()
1102        + byte_offset_map.len()
1103        + total_byte_data
1104        + empty_json.len();
1105    let mut buf = Vec::with_capacity(total_size);
1106
1107    buf.extend_from_slice(BYTEMODEL_MAGIC_STR);
1108
1109    buf.extend_from_slice(&(json_attrs_bytes.len() as u64).to_be_bytes());
1110    buf.extend_from_slice(&json_attrs_bytes);
1111
1112    buf.extend_from_slice(&(empty_json.len() as u64).to_be_bytes());
1113    buf.extend_from_slice(empty_json);
1114
1115    buf.extend_from_slice(&(byte_offset_map.len() as u64).to_be_bytes());
1116    buf.extend_from_slice(&byte_offset_map);
1117    for (_, data) in byte_attrs {
1118        buf.extend_from_slice(data);
1119    }
1120
1121    buf.extend_from_slice(&(empty_json.len() as u64).to_be_bytes());
1122    buf.extend_from_slice(empty_json);
1123
1124    Ok(buf)
1125}
1126
1127// =========================================================================
1128// Arrow serialization helpers
1129// =========================================================================
1130
1131fn serialize_record_batch_to_feather(batch: &RecordBatch) -> Result<Vec<u8>> {
1132    let mut buf = Vec::new();
1133
1134    {
1135        let mut writer = FileWriter::try_new(&mut buf, &batch.schema())?;
1136        writer.write(batch)?;
1137        writer.finish()?;
1138    }
1139
1140    Ok(buf)
1141}
1142
1143// =========================================================================
1144// Unit tests
1145// =========================================================================
1146#[cfg(test)]
1147mod tests {
1148    use super::*;
1149    use arrow::array::Int32Array;
1150    use arrow::datatypes::{DataType, Field, Schema};
1151    use std::sync::Arc;
1152
1153    #[test]
1154    fn test_serialize_record_batch_to_feather() {
1155        let schema = Arc::new(Schema::new(vec![Field::new(
1156            "user.id",
1157            DataType::Int32,
1158            false,
1159        )]));
1160        let batch =
1161            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap();
1162
1163        let feather_bytes = serialize_record_batch_to_feather(&batch).unwrap();
1164        assert!(!feather_bytes.is_empty());
1165        assert_eq!(&feather_bytes[..6], b"ARROW1");
1166    }
1167
1168    #[test]
1169    fn test_build_feather_request_body() {
1170        let header = FeatherRequestHeader {
1171            outputs: vec!["user.id".into()],
1172            expression_outputs: vec![],
1173            now: None,
1174            staleness: None,
1175            context: None,
1176            include_meta: true,
1177            explain: false,
1178            correlation_id: None,
1179            query_name: None,
1180            query_name_version: None,
1181            deployment_id: None,
1182            branch_id: None,
1183            meta: None,
1184            store_plan_stages: Some(false),
1185            query_context: None,
1186            encoding_options: FeatureEncodingOptions {
1187                encode_structs_as_objects: None,
1188            },
1189            planner_options: None,
1190            value_metrics_tag_by_features: vec![],
1191            overlay_graph: None,
1192        };
1193
1194        let fake_feather = b"ARROW1fake_feather_data";
1195        let body = build_feather_request_body(&header, fake_feather).unwrap();
1196
1197        assert_eq!(&body[..5], b"chal1");
1198
1199        let header_len = u64::from_be_bytes(body[5..13].try_into().unwrap()) as usize;
1200        assert!(header_len > 0);
1201
1202        let header_json_str = std::str::from_utf8(&body[13..13 + header_len]).unwrap();
1203        let parsed: Value = serde_json::from_str(header_json_str).unwrap();
1204        assert_eq!(parsed["outputs"][0], "user.id");
1205        assert_eq!(parsed["include_meta"], true);
1206
1207        let body_len_start = 13 + header_len;
1208        let body_len =
1209            u64::from_be_bytes(body[body_len_start..body_len_start + 8].try_into().unwrap())
1210                as usize;
1211        assert_eq!(body_len, fake_feather.len());
1212
1213        let body_start = body_len_start + 8;
1214        assert_eq!(&body[body_start..body_start + body_len], fake_feather);
1215    }
1216
1217    #[tokio::test]
1218    async fn test_client_builder() {
1219        let mut server = mockito::Server::new_async().await;
1220
1221        let mock = server
1222            .mock("POST", "/v1/oauth/token")
1223            .with_status(200)
1224            .with_header("content-type", "application/json")
1225            .with_body(
1226                serde_json::json!({
1227                    "access_token": "test-jwt",
1228                    "expires_in": 3600,
1229                    "primary_environment": "env-123",
1230                    "engines": {"env-123": server.url()},
1231                    "grpc_engines": {},
1232                    "environment_id_to_name": {"env-123": "production"}
1233                })
1234                .to_string(),
1235            )
1236            .create_async()
1237            .await;
1238
1239        let client = ChalkClient::new()
1240            .client_id("test-id")
1241            .client_secret("test-secret")
1242            .api_server(&server.url())
1243            .environment("env-123")
1244            .build()
1245            .await
1246            .unwrap();
1247
1248        assert_eq!(client.environment_id(), "env-123");
1249        assert_eq!(client.query_server(), &server.url());
1250        mock.assert_async().await;
1251    }
1252
1253    #[tokio::test]
1254    async fn test_query() {
1255        let mut server = mockito::Server::new_async().await;
1256
1257        server
1258            .mock("POST", "/v1/oauth/token")
1259            .with_status(200)
1260            .with_header("content-type", "application/json")
1261            .with_body(
1262                serde_json::json!({
1263                    "access_token": "test-jwt",
1264                    "expires_in": 3600,
1265                    "primary_environment": "env-1",
1266                    "engines": {"env-1": server.url()},
1267                    "grpc_engines": {}
1268                })
1269                .to_string(),
1270            )
1271            .create_async()
1272            .await;
1273
1274        let query_mock = server
1275            .mock("POST", "/v1/query/online")
1276            .match_header("Authorization", "Bearer test-jwt")
1277            .match_header("X-Chalk-Env-Id", "env-1")
1278            .with_status(200)
1279            .with_header("content-type", "application/json")
1280            .with_body(
1281                serde_json::json!({
1282                    "data": [
1283                        {"field": "user.age", "value": 25},
1284                        {"field": "user.name", "value": "Alice"}
1285                    ],
1286                    "errors": []
1287                })
1288                .to_string(),
1289            )
1290            .create_async()
1291            .await;
1292
1293        let client = ChalkClient::new()
1294            .client_id("test-id")
1295            .client_secret("test-secret")
1296            .api_server(&server.url())
1297            .environment("env-1")
1298            .build()
1299            .await
1300            .unwrap();
1301
1302        let inputs = HashMap::from([("user.id".into(), serde_json::json!(42))]);
1303        let outputs = vec!["user.age".into(), "user.name".into()];
1304
1305        let response = client
1306            .query(inputs, outputs, QueryOptions::default())
1307            .await
1308            .unwrap();
1309
1310        assert_eq!(response.data.len(), 2);
1311        assert_eq!(response.data[0].field, "user.age");
1312        assert_eq!(response.data[0].value, serde_json::json!(25));
1313        assert_eq!(response.data[1].field, "user.name");
1314        assert_eq!(response.data[1].value, serde_json::json!("Alice"));
1315
1316        query_mock.assert_async().await;
1317    }
1318
1319    #[tokio::test]
1320    async fn test_query_api_error() {
1321        let mut server = mockito::Server::new_async().await;
1322
1323        server
1324            .mock("POST", "/v1/oauth/token")
1325            .with_status(200)
1326            .with_header("content-type", "application/json")
1327            .with_body(
1328                serde_json::json!({
1329                    "access_token": "jwt",
1330                    "expires_in": 3600,
1331                    "primary_environment": "env-1",
1332                    "engines": {"env-1": server.url()},
1333                    "grpc_engines": {}
1334                })
1335                .to_string(),
1336            )
1337            .create_async()
1338            .await;
1339
1340        server
1341            .mock("POST", "/v1/query/online")
1342            .with_status(500)
1343            .with_body("internal server error")
1344            .create_async()
1345            .await;
1346
1347        let client = ChalkClient::new()
1348            .client_id("id")
1349            .client_secret("secret")
1350            .api_server(&server.url())
1351            .environment("env-1")
1352            .build()
1353            .await
1354            .unwrap();
1355
1356        let result = client
1357            .query(HashMap::new(), vec![], QueryOptions::default())
1358            .await;
1359
1360        assert!(result.is_err());
1361        match result.unwrap_err() {
1362            ChalkClientError::Api { status, message } => {
1363                assert_eq!(status, 500);
1364                assert!(message.contains("internal server error"));
1365            }
1366            other => panic!("expected Api error, got: {:?}", other),
1367        }
1368    }
1369
1370    #[tokio::test]
1371    async fn test_offline_query() {
1372        let mut server = mockito::Server::new_async().await;
1373
1374        server
1375            .mock("POST", "/v1/oauth/token")
1376            .with_status(200)
1377            .with_header("content-type", "application/json")
1378            .with_body(
1379                serde_json::json!({
1380                    "access_token": "jwt",
1381                    "expires_in": 3600,
1382                    "primary_environment": "env-1",
1383                    "engines": {"env-1": server.url()},
1384                    "grpc_engines": {}
1385                })
1386                .to_string(),
1387            )
1388            .create_async()
1389            .await;
1390
1391        let offline_mock = server
1392            .mock("POST", "/v4/offline_query")
1393            .match_header("Authorization", "Bearer jwt")
1394            .with_status(200)
1395            .with_header("content-type", "application/json")
1396            .with_body(
1397                serde_json::json!({
1398                    "is_finished": false,
1399                    "dataset_id": "ds-123",
1400                    "revisions": [{
1401                        "revision_id": "rev-1",
1402                        "status": "pending"
1403                    }],
1404                    "errors": []
1405                })
1406                .to_string(),
1407            )
1408            .create_async()
1409            .await;
1410
1411        let client = ChalkClient::new()
1412            .client_id("id")
1413            .client_secret("secret")
1414            .api_server(&server.url())
1415            .environment("env-1")
1416            .build()
1417            .await
1418            .unwrap();
1419
1420        let request = OfflineQueryRequest {
1421            input: None,
1422            output: vec!["user.ltv".into()],
1423            destination_format: Some("PARQUET".into()),
1424            job_id: None,
1425            max_samples: None,
1426            max_cache_age_secs: None,
1427            observed_at_lower_bound: None,
1428            observed_at_upper_bound: None,
1429            dataset_name: None,
1430            branch: None,
1431            recompute_features: None,
1432            tags: None,
1433            required_resolver_tags: None,
1434            correlation_id: None,
1435            store_online: None,
1436            store_offline: None,
1437            required_output: None,
1438            run_asynchronously: None,
1439            num_shards: None,
1440            num_workers: None,
1441            resources: None,
1442            completion_deadline: None,
1443            max_retries: None,
1444            store_plan_stages: None,
1445            explain: None,
1446            planner_options: None,
1447            query_context: None,
1448            use_multiple_computers: None,
1449            spine_sql_query: None,
1450            query_name: None,
1451            query_name_version: None,
1452        };
1453
1454        let response = client.offline_query_raw(request).await.unwrap();
1455        assert!(!response.is_finished);
1456        assert_eq!(response.dataset_id.as_deref(), Some("ds-123"));
1457        assert_eq!(response.revisions.len(), 1);
1458
1459        offline_mock.assert_async().await;
1460    }
1461
1462    #[tokio::test]
1463    async fn test_offline_query_with_builder() {
1464        let mut server = mockito::Server::new_async().await;
1465
1466        server
1467            .mock("POST", "/v1/oauth/token")
1468            .with_status(200)
1469            .with_header("content-type", "application/json")
1470            .with_body(
1471                serde_json::json!({
1472                    "access_token": "jwt",
1473                    "expires_in": 3600,
1474                    "primary_environment": "env-1",
1475                    "engines": {"env-1": server.url()},
1476                    "grpc_engines": {}
1477                })
1478                .to_string(),
1479            )
1480            .create_async()
1481            .await;
1482
1483        let offline_mock = server
1484            .mock("POST", "/v4/offline_query")
1485            .with_status(200)
1486            .with_header("content-type", "application/json")
1487            .with_body(
1488                serde_json::json!({
1489                    "is_finished": false,
1490                    "dataset_id": "ds-456",
1491                    "revisions": [{
1492                        "revision_id": "rev-2",
1493                        "status": "pending"
1494                    }],
1495                    "errors": []
1496                })
1497                .to_string(),
1498            )
1499            .create_async()
1500            .await;
1501
1502        let client = ChalkClient::new()
1503            .client_id("id")
1504            .client_secret("secret")
1505            .api_server(&server.url())
1506            .environment("env-1")
1507            .build()
1508            .await
1509            .unwrap();
1510
1511        use crate::offline::OfflineQueryParams;
1512
1513        let response = client
1514            .offline_query(
1515                OfflineQueryParams::new()
1516                    .with_input("user.id", vec![serde_json::json!(1), serde_json::json!(2)])
1517                    .with_output("user.email")
1518                    .with_output("user.ltv")
1519                    .with_num_shards(4),
1520            )
1521            .await
1522            .unwrap();
1523
1524        assert!(!response.is_finished);
1525        assert_eq!(response.dataset_id.as_deref(), Some("ds-456"));
1526        offline_mock.assert_async().await;
1527    }
1528
1529    #[tokio::test]
1530    async fn test_wait_for_offline_query_success() {
1531        let mut server = mockito::Server::new_async().await;
1532
1533        server
1534            .mock("POST", "/v1/oauth/token")
1535            .with_status(200)
1536            .with_header("content-type", "application/json")
1537            .with_body(
1538                serde_json::json!({
1539                    "access_token": "jwt",
1540                    "expires_in": 3600,
1541                    "primary_environment": "env-1",
1542                    "engines": {"env-1": server.url()},
1543                    "grpc_engines": {}
1544                })
1545                .to_string(),
1546            )
1547            .create_async()
1548            .await;
1549
1550        server
1551            .mock("GET", "/v4/offline_query/rev-1/status")
1552            .with_status(200)
1553            .with_header("content-type", "application/json")
1554            .with_body(
1555                serde_json::json!({
1556                    "report": {
1557                        "status": "RUNNING"
1558                    }
1559                })
1560                .to_string(),
1561            )
1562            .create_async()
1563            .await;
1564
1565        server
1566            .mock("GET", "/v4/offline_query/rev-1/status")
1567            .with_status(200)
1568            .with_header("content-type", "application/json")
1569            .with_body(
1570                serde_json::json!({
1571                    "report": {
1572                        "status": "COMPLETED"
1573                    }
1574                })
1575                .to_string(),
1576            )
1577            .create_async()
1578            .await;
1579
1580        let client = ChalkClient::new()
1581            .client_id("id")
1582            .client_secret("secret")
1583            .api_server(&server.url())
1584            .environment("env-1")
1585            .build()
1586            .await
1587            .unwrap();
1588
1589        let response = OfflineQueryResponse {
1590            is_finished: false,
1591            version: None,
1592            dataset_id: Some("ds-123".into()),
1593            dataset_name: None,
1594            environment_id: None,
1595            revisions: vec![crate::types::DatasetRevision {
1596                revision_id: Some("rev-1".into()),
1597                creator_id: None,
1598                environment_id: None,
1599                outputs: vec![],
1600                status: Some("pending".into()),
1601                num_partitions: None,
1602                output_uris: None,
1603                created_at: None,
1604                started_at: None,
1605                terminated_at: None,
1606                dashboard_url: None,
1607                dataset_name: None,
1608                dataset_id: None,
1609                branch: None,
1610            }],
1611            errors: vec![],
1612        };
1613
1614        let result = client
1615            .wait_for_offline_query(&response, Some(Duration::from_secs(5)))
1616            .await;
1617        assert!(result.is_ok());
1618    }
1619
1620    #[tokio::test]
1621    async fn test_wait_for_offline_query_failure() {
1622        let mut server = mockito::Server::new_async().await;
1623
1624        server
1625            .mock("POST", "/v1/oauth/token")
1626            .with_status(200)
1627            .with_header("content-type", "application/json")
1628            .with_body(
1629                serde_json::json!({
1630                    "access_token": "jwt",
1631                    "expires_in": 3600,
1632                    "primary_environment": "env-1",
1633                    "engines": {"env-1": server.url()},
1634                    "grpc_engines": {}
1635                })
1636                .to_string(),
1637            )
1638            .create_async()
1639            .await;
1640
1641        server
1642            .mock("GET", "/v4/offline_query/rev-1/status")
1643            .with_status(200)
1644            .with_header("content-type", "application/json")
1645            .with_body(
1646                serde_json::json!({
1647                    "report": {
1648                        "status": "FAILED",
1649                        "all_errors": [{
1650                            "code": "INTERNAL_ERROR",
1651                            "category": "REQUEST",
1652                            "message": "job failed due to OOM"
1653                        }]
1654                    }
1655                })
1656                .to_string(),
1657            )
1658            .create_async()
1659            .await;
1660
1661        let client = ChalkClient::new()
1662            .client_id("id")
1663            .client_secret("secret")
1664            .api_server(&server.url())
1665            .environment("env-1")
1666            .build()
1667            .await
1668            .unwrap();
1669
1670        let response = OfflineQueryResponse {
1671            is_finished: false,
1672            version: None,
1673            dataset_id: None,
1674            dataset_name: None,
1675            environment_id: None,
1676            revisions: vec![crate::types::DatasetRevision {
1677                revision_id: Some("rev-1".into()),
1678                creator_id: None,
1679                environment_id: None,
1680                outputs: vec![],
1681                status: None,
1682                num_partitions: None,
1683                output_uris: None,
1684                created_at: None,
1685                started_at: None,
1686                terminated_at: None,
1687                dashboard_url: None,
1688                dataset_name: None,
1689                dataset_id: None,
1690                branch: None,
1691            }],
1692            errors: vec![],
1693        };
1694
1695        let result = client
1696            .wait_for_offline_query(&response, Some(Duration::from_secs(5)))
1697            .await;
1698        assert!(result.is_err());
1699        let err = result.unwrap_err().to_string();
1700        assert!(err.contains("OOM"));
1701    }
1702
1703    #[tokio::test]
1704    async fn test_get_offline_query_download_urls() {
1705        let mut server = mockito::Server::new_async().await;
1706
1707        server
1708            .mock("POST", "/v1/oauth/token")
1709            .with_status(200)
1710            .with_header("content-type", "application/json")
1711            .with_body(
1712                serde_json::json!({
1713                    "access_token": "jwt",
1714                    "expires_in": 3600,
1715                    "primary_environment": "env-1",
1716                    "engines": {"env-1": server.url()},
1717                    "grpc_engines": {}
1718                })
1719                .to_string(),
1720            )
1721            .create_async()
1722            .await;
1723
1724        server
1725            .mock("GET", "/v2/offline_query/rev-1")
1726            .with_status(200)
1727            .with_header("content-type", "application/json")
1728            .with_body(
1729                serde_json::json!({
1730                    "is_finished": false,
1731                    "urls": [],
1732                    "errors": []
1733                })
1734                .to_string(),
1735            )
1736            .create_async()
1737            .await;
1738
1739        server
1740            .mock("GET", "/v2/offline_query/rev-1")
1741            .with_status(200)
1742            .with_header("content-type", "application/json")
1743            .with_body(
1744                serde_json::json!({
1745                    "is_finished": true,
1746                    "urls": [
1747                        "https://storage.example.com/results/part-0.parquet",
1748                        "https://storage.example.com/results/part-1.parquet"
1749                    ],
1750                    "errors": []
1751                })
1752                .to_string(),
1753            )
1754            .create_async()
1755            .await;
1756
1757        let client = ChalkClient::new()
1758            .client_id("id")
1759            .client_secret("secret")
1760            .api_server(&server.url())
1761            .environment("env-1")
1762            .build()
1763            .await
1764            .unwrap();
1765
1766        let response = OfflineQueryResponse {
1767            is_finished: false,
1768            version: None,
1769            dataset_id: None,
1770            dataset_name: None,
1771            environment_id: None,
1772            revisions: vec![crate::types::DatasetRevision {
1773                revision_id: Some("rev-1".into()),
1774                creator_id: None,
1775                environment_id: None,
1776                outputs: vec![],
1777                status: None,
1778                num_partitions: None,
1779                output_uris: None,
1780                created_at: None,
1781                started_at: None,
1782                terminated_at: None,
1783                dashboard_url: None,
1784                dataset_name: None,
1785                dataset_id: None,
1786                branch: None,
1787            }],
1788            errors: vec![],
1789        };
1790
1791        let urls = client
1792            .get_offline_query_download_urls(&response, Some(Duration::from_secs(5)))
1793            .await
1794            .unwrap();
1795
1796        assert_eq!(urls.len(), 2);
1797        assert!(urls[0].contains("part-0.parquet"));
1798        assert!(urls[1].contains("part-1.parquet"));
1799    }
1800
1801    #[tokio::test]
1802    async fn test_wait_for_offline_query_timeout() {
1803        let mut server = mockito::Server::new_async().await;
1804
1805        server
1806            .mock("POST", "/v1/oauth/token")
1807            .with_status(200)
1808            .with_header("content-type", "application/json")
1809            .with_body(
1810                serde_json::json!({
1811                    "access_token": "jwt",
1812                    "expires_in": 3600,
1813                    "primary_environment": "env-1",
1814                    "engines": {"env-1": server.url()},
1815                    "grpc_engines": {}
1816                })
1817                .to_string(),
1818            )
1819            .create_async()
1820            .await;
1821
1822        server
1823            .mock("GET", "/v4/offline_query/rev-1/status")
1824            .with_status(200)
1825            .with_header("content-type", "application/json")
1826            .with_body(
1827                serde_json::json!({
1828                    "report": {
1829                        "status": "RUNNING"
1830                    }
1831                })
1832                .to_string(),
1833            )
1834            .expect_at_least(1)
1835            .create_async()
1836            .await;
1837
1838        let client = ChalkClient::new()
1839            .client_id("id")
1840            .client_secret("secret")
1841            .api_server(&server.url())
1842            .environment("env-1")
1843            .build()
1844            .await
1845            .unwrap();
1846
1847        let response = OfflineQueryResponse {
1848            is_finished: false,
1849            version: None,
1850            dataset_id: None,
1851            dataset_name: None,
1852            environment_id: None,
1853            revisions: vec![crate::types::DatasetRevision {
1854                revision_id: Some("rev-1".into()),
1855                creator_id: None,
1856                environment_id: None,
1857                outputs: vec![],
1858                status: None,
1859                num_partitions: None,
1860                output_uris: None,
1861                created_at: None,
1862                started_at: None,
1863                terminated_at: None,
1864                dashboard_url: None,
1865                dataset_name: None,
1866                dataset_id: None,
1867                branch: None,
1868            }],
1869            errors: vec![],
1870        };
1871
1872        let result = client
1873            .wait_for_offline_query(&response, Some(Duration::from_millis(500)))
1874            .await;
1875        assert!(result.is_err());
1876        let err = result.unwrap_err().to_string();
1877        assert!(err.contains("timed out"));
1878    }
1879
1880    #[test]
1881    fn test_build_byte_base_model() {
1882        let json_attrs = serde_json::json!({
1883            "features": ["user.id", "user.age"],
1884            "table_compression": "uncompressed",
1885        });
1886        let fake_arrow = b"ARROW1fake_data_here";
1887
1888        let body = build_byte_base_model(&json_attrs, &[("table_bytes", fake_arrow.as_slice())])
1889            .unwrap();
1890
1891        let mut pos = 0;
1892
1893        assert_eq!(
1894            &body[pos..pos + BYTEMODEL_MAGIC_STR.len()],
1895            BYTEMODEL_MAGIC_STR
1896        );
1897        pos += BYTEMODEL_MAGIC_STR.len();
1898
1899        let json_attrs_len =
1900            u64::from_be_bytes(body[pos..pos + 8].try_into().unwrap()) as usize;
1901        pos += 8;
1902        let json_attrs_parsed: Value =
1903            serde_json::from_slice(&body[pos..pos + json_attrs_len]).unwrap();
1904        assert_eq!(json_attrs_parsed["features"][0], "user.id");
1905        assert_eq!(json_attrs_parsed["table_compression"], "uncompressed");
1906        pos += json_attrs_len;
1907
1908        let pydantic_len =
1909            u64::from_be_bytes(body[pos..pos + 8].try_into().unwrap()) as usize;
1910        pos += 8;
1911        let pydantic: Value =
1912            serde_json::from_slice(&body[pos..pos + pydantic_len]).unwrap();
1913        assert_eq!(pydantic, serde_json::json!({}));
1914        pos += pydantic_len;
1915
1916        let byte_map_len =
1917            u64::from_be_bytes(body[pos..pos + 8].try_into().unwrap()) as usize;
1918        pos += 8;
1919        let byte_map: Value =
1920            serde_json::from_slice(&body[pos..pos + byte_map_len]).unwrap();
1921        assert_eq!(byte_map["table_bytes"], fake_arrow.len() as u64);
1922        pos += byte_map_len;
1923
1924        assert_eq!(&body[pos..pos + fake_arrow.len()], fake_arrow);
1925        pos += fake_arrow.len();
1926
1927        let ser_len =
1928            u64::from_be_bytes(body[pos..pos + 8].try_into().unwrap()) as usize;
1929        pos += 8;
1930        let ser: Value = serde_json::from_slice(&body[pos..pos + ser_len]).unwrap();
1931        assert_eq!(ser, serde_json::json!({}));
1932        pos += ser_len;
1933
1934        assert_eq!(pos, body.len());
1935    }
1936
1937    #[tokio::test]
1938    async fn test_upload_features() {
1939        let mut server = mockito::Server::new_async().await;
1940
1941        server
1942            .mock("POST", "/v1/oauth/token")
1943            .with_status(200)
1944            .with_header("content-type", "application/json")
1945            .with_body(
1946                serde_json::json!({
1947                    "access_token": "jwt",
1948                    "expires_in": 3600,
1949                    "primary_environment": "env-1",
1950                    "engines": {"env-1": server.url()},
1951                    "grpc_engines": {}
1952                })
1953                .to_string(),
1954            )
1955            .create_async()
1956            .await;
1957
1958        let upload_mock = server
1959            .mock("POST", "/v1/upload_features/multi")
1960            .match_header("Authorization", "Bearer jwt")
1961            .match_header("Content-Type", "application/octet-stream")
1962            .with_status(200)
1963            .with_header("content-type", "application/json")
1964            .with_body(
1965                serde_json::json!({
1966                    "operation_id": "op-abc-123",
1967                    "errors": []
1968                })
1969                .to_string(),
1970            )
1971            .create_async()
1972            .await;
1973
1974        let client = ChalkClient::new()
1975            .client_id("id")
1976            .client_secret("secret")
1977            .api_server(&server.url())
1978            .environment("env-1")
1979            .build()
1980            .await
1981            .unwrap();
1982
1983        let schema = Arc::new(Schema::new(vec![
1984            Field::new("user.id", DataType::Int32, false),
1985            Field::new("user.age", DataType::Int32, true),
1986        ]));
1987        let batch = RecordBatch::try_new(
1988            schema,
1989            vec![
1990                Arc::new(Int32Array::from(vec![1, 2, 3])),
1991                Arc::new(Int32Array::from(vec![25, 30, 22])),
1992            ],
1993        )
1994        .unwrap();
1995
1996        let result = client.upload_features(&batch).await.unwrap();
1997        assert_eq!(result.operation_id.as_deref(), Some("op-abc-123"));
1998        assert!(result.errors.is_empty());
1999
2000        upload_mock.assert_async().await;
2001    }
2002
2003    #[tokio::test]
2004    async fn test_upload_features_map() {
2005        let mut server = mockito::Server::new_async().await;
2006
2007        server
2008            .mock("POST", "/v1/oauth/token")
2009            .with_status(200)
2010            .with_header("content-type", "application/json")
2011            .with_body(
2012                serde_json::json!({
2013                    "access_token": "jwt",
2014                    "expires_in": 3600,
2015                    "primary_environment": "env-1",
2016                    "engines": {"env-1": server.url()},
2017                    "grpc_engines": {}
2018                })
2019                .to_string(),
2020            )
2021            .create_async()
2022            .await;
2023
2024        let upload_mock = server
2025            .mock("POST", "/v1/upload_features/multi")
2026            .with_status(200)
2027            .with_header("content-type", "application/json")
2028            .with_body(
2029                serde_json::json!({
2030                    "operation_id": "op-map-456",
2031                    "errors": []
2032                })
2033                .to_string(),
2034            )
2035            .create_async()
2036            .await;
2037
2038        let client = ChalkClient::new()
2039            .client_id("id")
2040            .client_secret("secret")
2041            .api_server(&server.url())
2042            .environment("env-1")
2043            .build()
2044            .await
2045            .unwrap();
2046
2047        let inputs = HashMap::from([
2048            (
2049                "user.id".to_string(),
2050                vec![serde_json::json!(1), serde_json::json!(2)],
2051            ),
2052            (
2053                "user.name".to_string(),
2054                vec![serde_json::json!("Alice"), serde_json::json!("Bob")],
2055            ),
2056        ]);
2057
2058        let result = client.upload_features_map(inputs).await.unwrap();
2059        assert_eq!(result.operation_id.as_deref(), Some("op-map-456"));
2060
2061        upload_mock.assert_async().await;
2062    }
2063
2064    #[tokio::test]
2065    async fn test_upload_features_map_empty_inputs() {
2066        let mut server = mockito::Server::new_async().await;
2067
2068        server
2069            .mock("POST", "/v1/oauth/token")
2070            .with_status(200)
2071            .with_header("content-type", "application/json")
2072            .with_body(
2073                serde_json::json!({
2074                    "access_token": "jwt",
2075                    "expires_in": 3600,
2076                    "primary_environment": "env-1",
2077                    "engines": {"env-1": server.url()},
2078                    "grpc_engines": {}
2079                })
2080                .to_string(),
2081            )
2082            .create_async()
2083            .await;
2084
2085        let client = ChalkClient::new()
2086            .client_id("id")
2087            .client_secret("secret")
2088            .api_server(&server.url())
2089            .environment("env-1")
2090            .build()
2091            .await
2092            .unwrap();
2093
2094        let result = client.upload_features_map(HashMap::new()).await;
2095        assert!(result.is_err());
2096        let err = result.unwrap_err().to_string();
2097        assert!(err.contains("at least one feature"));
2098    }
2099}