1use std::collections::HashMap;
30use std::time::Duration;
31
32use arrow::ipc::writer::FileWriter;
33use arrow::record_batch::RecordBatch;
34use serde::Serialize;
35use serde_json::Value;
36
37use crate::auth::TokenManager;
38use crate::config::{ChalkClientConfig, ChalkClientConfigBuilder, ensure_scheme};
39use crate::error::{ChalkClientError, Result};
40use crate::offline::OfflineQueryParams;
41use crate::types::{
42 FeatureEncodingOptions, GetOfflineQueryJobResponse, GetOfflineQueryStatusResponse,
43 OfflineQueryRequest, OfflineQueryResponse, OnlineQueryContext, OnlineQueryRequest,
44 OnlineQueryResponse, QueryOptions, UploadFeaturesResult,
45};
46
/// Value sent in the `User-Agent` header by the request builders in this file.
const USER_AGENT: &str = "chalk-rust/0.1.0";

/// Magic prefix written at the very start of a feather (bulk) query request body.
const MULTI_QUERY_MAGIC_STR: &[u8] = b"chal1";

/// Magic prefix that opens a ByteBaseModel-framed payload; written when building
/// upload bodies and expected when parsing bulk query responses.
const BYTEMODEL_MAGIC_STR: &[u8] = b"CHALK_BYTE_TRANSMISSION";
55
56pub struct ChalkClient {
62 config: ChalkClientConfig,
64
65 token_manager: TokenManager,
67
68 http_client: reqwest::Client,
70
71 query_server: String,
73
74 environment_id: String,
76}
77
78pub struct ChalkClientBuilder {
84 config_builder: ChalkClientConfigBuilder,
85}
86
87impl ChalkClient {
88 #[allow(clippy::new_ret_no_self)]
90 pub fn new() -> ChalkClientBuilder {
91 ChalkClientBuilder {
92 config_builder: ChalkClientConfigBuilder::new(),
93 }
94 }
95}
96
97impl ChalkClientBuilder {
98 pub fn client_id(mut self, id: impl Into<String>) -> Self {
100 self.config_builder = self.config_builder.client_id(id);
101 self
102 }
103
104 pub fn client_secret(mut self, secret: impl Into<String>) -> Self {
106 self.config_builder = self.config_builder.client_secret(secret);
107 self
108 }
109
110 pub fn api_server(mut self, url: impl Into<String>) -> Self {
112 self.config_builder = self.config_builder.api_server(url);
113 self
114 }
115
116 pub fn environment(mut self, env: impl Into<String>) -> Self {
118 self.config_builder = self.config_builder.environment(env);
119 self
120 }
121
122 pub fn branch_id(mut self, id: impl Into<String>) -> Self {
124 self.config_builder = self.config_builder.branch_id(id);
125 self
126 }
127
128 pub fn deployment_tag(mut self, tag: impl Into<String>) -> Self {
130 self.config_builder = self.config_builder.deployment_tag(tag);
131 self
132 }
133
134 pub fn query_server(mut self, url: impl Into<String>) -> Self {
136 self.config_builder = self.config_builder.query_server(url);
137 self
138 }
139
140 pub async fn build(self) -> Result<ChalkClient> {
145 let config = self.config_builder.build()?;
146
147 let token_manager = TokenManager::new(config.clone());
148 let token = token_manager.get_token().await?;
149
150 let environment_id = config
151 .environment
152 .clone()
153 .or(token.primary_environment.clone())
154 .ok_or_else(|| {
155 ChalkClientError::Config(
156 "no environment specified and token has no primary_environment".into(),
157 )
158 })?;
159
160 let query_server = ensure_scheme(
161 config
162 .query_server
163 .clone()
164 .or_else(|| token.engines.get(&environment_id).cloned())
165 .unwrap_or_else(|| config.api_server.clone()),
166 );
167
168 tracing::info!(
169 environment = %environment_id,
170 query_server = %query_server,
171 "ChalkClient initialized"
172 );
173
174 Ok(ChalkClient {
175 config,
176 token_manager,
177 http_client: reqwest::Client::new(),
178 query_server,
179 environment_id,
180 })
181 }
182}
183
184impl ChalkClient {
189 pub async fn query(
197 &self,
198 inputs: HashMap<String, Value>,
199 outputs: Vec<String>,
200 options: QueryOptions,
201 ) -> Result<OnlineQueryResponse> {
202 let url = format!("{}/v1/query/online", self.engine_url());
203
204 let body = OnlineQueryRequest {
205 inputs,
206 outputs,
207 context: options.context,
208 staleness: options.staleness,
209 include_meta: options.include_meta,
210 query_name: options.query_name,
211 correlation_id: options.correlation_id,
212 query_context: options.query_context,
213 meta: options.meta,
214 query_name_version: options.query_name_version,
215 now: options.now,
216 explain: options.explain,
217 store_plan_stages: options.store_plan_stages,
218 encoding_options: options.encoding_options,
219 branch_id: options.branch_id.or(self.config.branch_id.clone()),
220 };
221
222 let resp = self
223 .send_json_request(reqwest::Method::POST, &url, &body)
224 .await?;
225
226 let status = resp.status();
227 let body_text = resp.text().await?;
228
229 if !status.is_success() {
230 return Err(ChalkClientError::Api {
231 status: status.as_u16(),
232 message: body_text,
233 });
234 }
235
236 let response: OnlineQueryResponse = serde_json::from_str(&body_text)?;
237
238 if !response.errors.is_empty() {
239 tracing::warn!(
240 error_count = response.errors.len(),
241 "query returned server errors"
242 );
243 }
244
245 Ok(response)
246 }
247
248 pub async fn query_bulk(
254 &self,
255 inputs: &RecordBatch,
256 outputs: Vec<String>,
257 options: QueryOptions,
258 ) -> Result<BulkQueryResult> {
259 let url = format!("{}/v1/query/feather", self.engine_url());
260
261 let header = FeatherRequestHeader {
262 outputs: outputs.clone(),
263 expression_outputs: vec![],
264 now: None,
265 staleness: options.staleness,
266 context: options.context,
267 include_meta: options.include_meta.unwrap_or(true),
268 explain: options.explain.unwrap_or(false),
269 correlation_id: options.correlation_id,
270 query_name: options.query_name,
271 query_name_version: options.query_name_version,
272 deployment_id: None,
273 branch_id: options.branch_id.or(self.config.branch_id.clone()),
274 meta: options.meta,
275 store_plan_stages: options.store_plan_stages.or(Some(false)),
276 query_context: options.query_context,
277 encoding_options: options
278 .encoding_options
279 .unwrap_or(FeatureEncodingOptions {
280 encode_structs_as_objects: None,
281 }),
282 planner_options: options.planner_options,
283 value_metrics_tag_by_features: vec![],
284 overlay_graph: None,
285 };
286
287 let feather_bytes = serialize_record_batch_to_feather(inputs)?;
288
289 let request_body = build_feather_request_body(&header, &feather_bytes)?;
290
291 let token = self.token_manager.get_token().await?;
292
293 let deployment_type = if self.config.branch_id.is_some() {
294 "branch"
295 } else {
296 "engine"
297 };
298
299 let mut request = self
300 .http_client
301 .post(&url)
302 .header("Authorization", format!("Bearer {}", token.access_token))
303 .header("User-Agent", USER_AGENT)
304 .header("Content-Type", "application/octet-stream")
305 .header("Accept", "application/octet-stream")
306 .header("X-Chalk-Client-Id", &self.config.client_id)
307 .header("X-Chalk-Env-Id", &self.environment_id)
308 .header("X-Chalk-Deployment-Type", deployment_type)
309 .header("X-Chalk-Features-Versioned", "true");
310
311 if let Some(ref branch) = self.config.branch_id {
312 request = request.header("X-Chalk-Branch-Id", branch.as_str());
313 }
314 if let Some(ref tag) = self.config.deployment_tag {
315 request = request.header("X-Chalk-Deployment-Tag", tag);
316 }
317
318 let resp = request.body(request_body).send().await?;
319
320 let status = resp.status();
321 if !status.is_success() {
322 let body = resp.text().await.unwrap_or_default();
323 return Err(ChalkClientError::Api {
324 status: status.as_u16(),
325 message: body,
326 });
327 }
328
329 let response_bytes = resp.bytes().await?;
330 parse_bulk_query_response(&response_bytes)
331 }
332
333 pub async fn offline_query(
349 &self,
350 params: OfflineQueryParams,
351 ) -> Result<OfflineQueryResponse> {
352 let request = params.build()?;
353 self.offline_query_raw(request).await
354 }
355
356 pub async fn offline_query_raw(
358 &self,
359 request: OfflineQueryRequest,
360 ) -> Result<OfflineQueryResponse> {
361 let url = format!("{}/v4/offline_query", self.config.api_server);
362
363 let resp = self
364 .send_json_request(reqwest::Method::POST, &url, &request)
365 .await?;
366
367 let status = resp.status();
368 let body_text = resp.text().await?;
369
370 if !status.is_success() {
371 return Err(ChalkClientError::Api {
372 status: status.as_u16(),
373 message: body_text,
374 });
375 }
376
377 let response: OfflineQueryResponse = serde_json::from_str(&body_text)?;
378 Ok(response)
379 }
380
381 pub async fn get_offline_query_status(
383 &self,
384 job_id: &str,
385 ) -> Result<GetOfflineQueryStatusResponse> {
386 let url = format!(
387 "{}/v4/offline_query/{}/status",
388 self.config.api_server, job_id
389 );
390
391 let resp = self
392 .send_get_request(&url)
393 .await?;
394
395 let status = resp.status();
396 let body_text = resp.text().await?;
397
398 if !status.is_success() {
399 return Err(ChalkClientError::Api {
400 status: status.as_u16(),
401 message: body_text,
402 });
403 }
404
405 let response: GetOfflineQueryStatusResponse = serde_json::from_str(&body_text)?;
406 Ok(response)
407 }
408
409 pub async fn wait_for_offline_query(
414 &self,
415 response: &OfflineQueryResponse,
416 timeout: Option<Duration>,
417 ) -> Result<()> {
418 let revision = response
419 .revisions
420 .last()
421 .and_then(|r| r.revision_id.as_deref())
422 .ok_or_else(|| {
423 ChalkClientError::Config("offline query response has no revision ID".into())
424 })?;
425
426 let poll_fut = async {
427 loop {
428 let status_resp = self.get_offline_query_status(revision).await?;
429 let report = match status_resp.report {
430 Some(r) => r,
431 None => {
432 tokio::time::sleep(Duration::from_secs(1)).await;
433 continue;
434 }
435 };
436 let status = report.status.as_deref().unwrap_or("UNKNOWN");
437
438 match status {
439 "COMPLETED" => return Ok(()),
440 "FAILED" => {
441 let errors = report.all_errors;
442 if errors.is_empty() {
443 if let Some(err) = report.error {
444 return Err(ChalkClientError::ServerErrors(vec![err]));
445 }
446 return Err(ChalkClientError::Api {
447 status: 0,
448 message: "offline query failed with no error details".into(),
449 });
450 }
451 return Err(ChalkClientError::ServerErrors(errors));
452 }
453 _ => {
454 tokio::time::sleep(Duration::from_secs(1)).await;
455 }
456 }
457 }
458 };
459
460 if let Some(timeout_dur) = timeout {
461 tokio::time::timeout(timeout_dur, poll_fut)
462 .await
463 .map_err(|_| {
464 ChalkClientError::Api {
465 status: 0,
466 message: format!(
467 "timed out waiting for offline query after {:?}",
468 timeout_dur
469 ),
470 }
471 })?
472 } else {
473 poll_fut.await
474 }
475 }
476
477 pub async fn get_offline_query_download_urls(
479 &self,
480 response: &OfflineQueryResponse,
481 timeout: Option<Duration>,
482 ) -> Result<Vec<String>> {
483 let revision_id = response
484 .revisions
485 .last()
486 .and_then(|r| r.revision_id.as_deref())
487 .ok_or_else(|| {
488 ChalkClientError::Config("offline query response has no revision ID".into())
489 })?;
490
491 let poll_fut = async {
492 loop {
493 let url = format!(
494 "{}/v2/offline_query/{}",
495 self.config.api_server, revision_id
496 );
497
498 let resp = self.send_get_request(&url).await?;
499 let status = resp.status();
500 let body_text = resp.text().await?;
501
502 if !status.is_success() {
503 return Err(ChalkClientError::Api {
504 status: status.as_u16(),
505 message: body_text,
506 });
507 }
508
509 let job_resp: GetOfflineQueryJobResponse = serde_json::from_str(&body_text)?;
510
511 if job_resp.is_finished {
512 if !job_resp.errors.is_empty() {
513 return Err(ChalkClientError::ServerErrors(job_resp.errors));
514 }
515 return Ok(job_resp.urls);
516 }
517
518 tokio::time::sleep(Duration::from_millis(500)).await;
519 }
520 };
521
522 if let Some(timeout_dur) = timeout {
523 tokio::time::timeout(timeout_dur, poll_fut)
524 .await
525 .map_err(|_| {
526 ChalkClientError::Api {
527 status: 0,
528 message: format!(
529 "timed out waiting for download URLs after {:?}",
530 timeout_dur
531 ),
532 }
533 })?
534 } else {
535 poll_fut.await
536 }
537 }
538
539 pub async fn upload_features(
546 &self,
547 features: &RecordBatch,
548 ) -> Result<UploadFeaturesResult> {
549 let url = format!("{}/v1/upload_features/multi", self.engine_url());
550
551 let feature_names: Vec<String> = features
552 .schema()
553 .fields()
554 .iter()
555 .map(|f| f.name().clone())
556 .collect();
557
558 let feather_bytes = serialize_record_batch_to_feather(features)?;
559
560 let json_attrs = serde_json::json!({
561 "features": feature_names,
562 "table_compression": "uncompressed",
563 });
564 let body = build_byte_base_model(&json_attrs, &[("table_bytes", &feather_bytes)])?;
565
566 let token = self.token_manager.get_token().await?;
567
568 let deployment_type = if self.config.branch_id.is_some() {
569 "branch"
570 } else {
571 "engine"
572 };
573
574 let mut request = self
575 .http_client
576 .post(&url)
577 .header("Authorization", format!("Bearer {}", token.access_token))
578 .header("User-Agent", USER_AGENT)
579 .header("Content-Type", "application/octet-stream")
580 .header("Accept", "application/json")
581 .header("X-Chalk-Client-Id", &self.config.client_id)
582 .header("X-Chalk-Env-Id", &self.environment_id)
583 .header("X-Chalk-Deployment-Type", deployment_type)
584 .header("X-Chalk-Features-Versioned", "true");
585
586 if let Some(ref branch) = self.config.branch_id {
587 request = request.header("X-Chalk-Branch-Id", branch.as_str());
588 }
589 if let Some(ref tag) = self.config.deployment_tag {
590 request = request.header("X-Chalk-Deployment-Tag", tag);
591 }
592
593 let resp = request.body(body).send().await?;
594
595 let status = resp.status();
596 let body_text = resp.text().await?;
597
598 if !status.is_success() {
599 return Err(ChalkClientError::Api {
600 status: status.as_u16(),
601 message: body_text,
602 });
603 }
604
605 let result: UploadFeaturesResult = serde_json::from_str(&body_text)?;
606
607 if !result.errors.is_empty() {
608 tracing::warn!(
609 error_count = result.errors.len(),
610 "upload_features returned server errors"
611 );
612 }
613
614 Ok(result)
615 }
616
617 pub async fn upload_features_map(
634 &self,
635 inputs: HashMap<String, Vec<Value>>,
636 ) -> Result<UploadFeaturesResult> {
637 use arrow::array::StringArray;
638 use arrow::datatypes::{DataType, Field, Schema};
639 use std::sync::Arc;
640
641 if inputs.is_empty() {
642 return Err(ChalkClientError::Config(
643 "upload_features_map requires at least one feature".into(),
644 ));
645 }
646
647 let mut feature_names: Vec<String> = inputs.keys().cloned().collect();
648 feature_names.sort();
649
650 let num_rows = inputs[&feature_names[0]].len();
651
652 let fields: Vec<Field> = feature_names
653 .iter()
654 .map(|name| Field::new(name, DataType::Utf8, true))
655 .collect();
656 let schema = Arc::new(Schema::new(fields));
657
658 let columns: Vec<Arc<dyn arrow::array::Array>> = feature_names
659 .iter()
660 .map(|name| {
661 let values = &inputs[name];
662 let strings: Vec<Option<String>> = values
663 .iter()
664 .map(|v| match v {
665 Value::Null => None,
666 Value::String(s) => Some(s.clone()),
667 other => Some(other.to_string()),
668 })
669 .collect();
670 Arc::new(StringArray::from(strings)) as Arc<dyn arrow::array::Array>
671 })
672 .collect();
673
674 let batch = RecordBatch::try_new(schema, columns).map_err(|e| {
675 ChalkClientError::Arrow(e)
676 })?;
677
678 if batch.num_rows() != num_rows {
679 return Err(ChalkClientError::Config(
680 "all input arrays must be the same length".into(),
681 ));
682 }
683
684 self.upload_features(&batch).await
685 }
686
687 pub fn environment_id(&self) -> &str {
689 &self.environment_id
690 }
691
692 pub fn query_server(&self) -> &str {
694 &self.query_server
695 }
696
697 fn engine_url(&self) -> &str {
702 if self.config.branch_id.is_some() {
703 &self.config.api_server
704 } else {
705 &self.query_server
706 }
707 }
708
709 async fn send_json_request<T: serde::Serialize>(
710 &self,
711 method: reqwest::Method,
712 url: &str,
713 body: &T,
714 ) -> Result<reqwest::Response> {
715 let token = self.token_manager.get_token().await?;
716
717 let deployment_type = if self.config.branch_id.is_some() {
718 "branch"
719 } else {
720 "engine"
721 };
722
723 let mut request = self
724 .http_client
725 .request(method, url)
726 .header("Authorization", format!("Bearer {}", token.access_token))
727 .header("Content-Type", "application/json")
728 .header("Accept", "application/json")
729 .header("User-Agent", USER_AGENT)
730 .header("X-Chalk-Client-Id", &self.config.client_id)
731 .header("X-Chalk-Env-Id", &self.environment_id)
732 .header("X-Chalk-Deployment-Type", deployment_type)
733 .header("X-Chalk-Features-Versioned", "true");
734
735 if let Some(ref branch) = self.config.branch_id {
736 request = request.header("X-Chalk-Branch-Id", branch.as_str());
737 }
738 if let Some(ref tag) = self.config.deployment_tag {
739 request = request.header("X-Chalk-Deployment-Tag", tag);
740 }
741
742 let resp = request.json(body).send().await?;
743 Ok(resp)
744 }
745
746 async fn send_get_request(&self, url: &str) -> Result<reqwest::Response> {
747 let token = self.token_manager.get_token().await?;
748
749 let deployment_type = if self.config.branch_id.is_some() {
750 "branch"
751 } else {
752 "engine"
753 };
754
755 let mut request = self
756 .http_client
757 .get(url)
758 .header("Authorization", format!("Bearer {}", token.access_token))
759 .header("Accept", "application/json")
760 .header("User-Agent", USER_AGENT)
761 .header("X-Chalk-Client-Id", &self.config.client_id)
762 .header("X-Chalk-Env-Id", &self.environment_id)
763 .header("X-Chalk-Deployment-Type", deployment_type)
764 .header("X-Chalk-Features-Versioned", "true");
765
766 if let Some(ref branch) = self.config.branch_id {
767 request = request.header("X-Chalk-Branch-Id", branch.as_str());
768 }
769 if let Some(ref tag) = self.config.deployment_tag {
770 request = request.header("X-Chalk-Deployment-Tag", tag);
771 }
772
773 let resp = request.send().await?;
774 Ok(resp)
775 }
776}
777
778#[derive(Debug, Serialize)]
783struct FeatherRequestHeader {
784 outputs: Vec<String>,
785 #[serde(default)]
786 expression_outputs: Vec<String>,
787 #[serde(skip_serializing_if = "Option::is_none")]
788 now: Option<Vec<String>>,
789 #[serde(skip_serializing_if = "Option::is_none")]
790 staleness: Option<HashMap<String, String>>,
791 #[serde(skip_serializing_if = "Option::is_none")]
792 context: Option<OnlineQueryContext>,
793 include_meta: bool,
794 explain: bool,
795 #[serde(skip_serializing_if = "Option::is_none")]
796 correlation_id: Option<String>,
797 #[serde(skip_serializing_if = "Option::is_none")]
798 query_name: Option<String>,
799 #[serde(skip_serializing_if = "Option::is_none")]
800 query_name_version: Option<String>,
801 #[serde(skip_serializing_if = "Option::is_none")]
802 deployment_id: Option<String>,
803 #[serde(skip_serializing_if = "Option::is_none")]
804 branch_id: Option<String>,
805 #[serde(skip_serializing_if = "Option::is_none")]
806 meta: Option<HashMap<String, String>>,
807 #[serde(skip_serializing_if = "Option::is_none")]
808 store_plan_stages: Option<bool>,
809 #[serde(skip_serializing_if = "Option::is_none")]
810 query_context: Option<HashMap<String, Value>>,
811 encoding_options: FeatureEncodingOptions,
812 #[serde(skip_serializing_if = "Option::is_none")]
813 planner_options: Option<HashMap<String, Value>>,
814 #[serde(default)]
815 value_metrics_tag_by_features: Vec<String>,
816 #[serde(skip_serializing_if = "Option::is_none")]
817 overlay_graph: Option<String>,
818}
819
/// Decoded result of a bulk (feather) query, as produced by
/// `parse_online_query_result_feather`.
#[derive(Debug)]
pub struct BulkQueryResult {
    /// Raw bytes of the `scalar_data` section; empty when the section is
    /// absent, zero-length, or extends past the end of the response.
    pub scalar_data: Vec<u8>,

    /// Value of the `has_data` attribute; `false` when missing.
    pub has_data: bool,

    /// Value of the `meta` attribute when it is a JSON string.
    pub meta: Option<String>,

    /// String entries of the `errors` attribute; non-string entries are skipped.
    pub errors: Vec<String>,
}
839
840fn build_feather_request_body(header: &FeatherRequestHeader, feather_bytes: &[u8]) -> Result<Vec<u8>> {
845 let header_json = serde_json::to_string(header)?;
846 let header_bytes = header_json.as_bytes();
847
848 let total_size = 5 + 8 + header_bytes.len() + 8 + feather_bytes.len();
849 let mut buf = Vec::with_capacity(total_size);
850
851 buf.extend_from_slice(MULTI_QUERY_MAGIC_STR);
852
853 buf.extend_from_slice(&(header_bytes.len() as u64).to_be_bytes());
854 buf.extend_from_slice(header_bytes);
855
856 buf.extend_from_slice(&(feather_bytes.len() as u64).to_be_bytes());
857 buf.extend_from_slice(feather_bytes);
858
859 Ok(buf)
860}
861
862fn parse_bulk_query_response(data: &[u8]) -> Result<BulkQueryResult> {
867 let mut pos: usize = 0;
868
869 pos = consume_magic(data, pos)?;
870
871 let (new_pos, _attrs_json) = read_length_prefixed_json(data, pos)?;
872 pos = new_pos;
873
874 let (new_pos, _pydantic_json) = read_length_prefixed_json(data, pos)?;
875 pos = new_pos;
876
877 let (new_pos, byte_offset_map) = read_length_prefixed_json(data, pos)?;
878 pos = new_pos;
879 pos = skip_byte_data(data, pos, &byte_offset_map)?;
880
881 let (new_pos, serializable_offset_map) = read_length_prefixed_json(data, pos)?;
882 pos = new_pos;
883
884 let query_results_len = serializable_offset_map
885 .get("query_results_bytes")
886 .and_then(|v| v.as_u64())
887 .ok_or_else(|| ChalkClientError::Api {
888 status: 0,
889 message: format!(
890 "missing query_results_bytes in serializable_attrs (got: {})",
891 serializable_offset_map
892 ),
893 })? as usize;
894
895 if pos + query_results_len > data.len() {
896 return Err(ChalkClientError::Api {
897 status: 0,
898 message: "response truncated: query_results_bytes extends beyond data".into(),
899 });
900 }
901 let query_results_bytes = &data[pos..pos + query_results_len];
902
903 parse_query_result_feather(query_results_bytes)
904}
905
906fn parse_query_result_feather(data: &[u8]) -> Result<BulkQueryResult> {
907 let mut pos: usize = 0;
908
909 pos = consume_magic(data, pos)?;
910
911 let (new_pos, _) = read_length_prefixed_json(data, pos)?;
912 pos = new_pos;
913
914 let (new_pos, _) = read_length_prefixed_json(data, pos)?;
915 pos = new_pos;
916
917 let (new_pos, byte_offset_map) = read_length_prefixed_json(data, pos)?;
918 pos = new_pos;
919
920 let (_query_key, result_len) = byte_offset_map
921 .as_object()
922 .and_then(|m| m.iter().next())
923 .and_then(|(k, v)| v.as_u64().map(|len| (k.clone(), len as usize)))
924 .ok_or_else(|| ChalkClientError::Api {
925 status: 0,
926 message: "empty byte_attrs in query results ByteDict".into(),
927 })?;
928
929 if pos + result_len > data.len() {
930 return Err(ChalkClientError::Api {
931 status: 0,
932 message: "response truncated: result bytes extend beyond data".into(),
933 });
934 }
935 let result_bytes = &data[pos..pos + result_len];
936
937 parse_online_query_result_feather(result_bytes)
938}
939
940fn parse_online_query_result_feather(data: &[u8]) -> Result<BulkQueryResult> {
941 let mut pos: usize = 0;
942
943 pos = consume_magic(data, pos)?;
944
945 let (new_pos, json_attrs) = read_length_prefixed_json(data, pos)?;
946 pos = new_pos;
947
948 let has_data = json_attrs
949 .get("has_data")
950 .and_then(|v| v.as_bool())
951 .unwrap_or(false);
952
953 let meta = json_attrs
954 .get("meta")
955 .and_then(|v| v.as_str())
956 .map(|s| s.to_string());
957
958 let errors: Vec<String> = json_attrs
959 .get("errors")
960 .and_then(|v| v.as_array())
961 .map(|arr| {
962 arr.iter()
963 .filter_map(|v| v.as_str().map(|s| s.to_string()))
964 .collect()
965 })
966 .unwrap_or_default();
967
968 let (new_pos, _) = read_length_prefixed_json(data, pos)?;
969 pos = new_pos;
970
971 let (new_pos, byte_offset_map) = read_length_prefixed_json(data, pos)?;
972 pos = new_pos;
973
974 let scalar_data_len = byte_offset_map
975 .get("scalar_data")
976 .and_then(|v| v.as_u64())
977 .unwrap_or(0) as usize;
978
979 let scalar_data = if scalar_data_len > 0 && pos + scalar_data_len <= data.len() {
980 data[pos..pos + scalar_data_len].to_vec()
981 } else {
982 vec![]
983 };
984
985 Ok(BulkQueryResult {
986 scalar_data,
987 has_data,
988 meta,
989 errors,
990 })
991}
992
993fn consume_magic(data: &[u8], pos: usize) -> Result<usize> {
994 if pos + BYTEMODEL_MAGIC_STR.len() > data.len() {
995 return Err(ChalkClientError::Api {
996 status: 0,
997 message: format!(
998 "response too short for magic string at position {} ({} bytes available)",
999 pos,
1000 data.len() - pos
1001 ),
1002 });
1003 }
1004 if &data[pos..pos + BYTEMODEL_MAGIC_STR.len()] != BYTEMODEL_MAGIC_STR {
1005 return Err(ChalkClientError::Api {
1006 status: 0,
1007 message: format!(
1008 "invalid ByteBaseModel magic at position {} (got {:?})",
1009 pos,
1010 &data[pos..std::cmp::min(pos + BYTEMODEL_MAGIC_STR.len(), data.len())]
1011 ),
1012 });
1013 }
1014 Ok(pos + BYTEMODEL_MAGIC_STR.len())
1015}
1016
1017fn skip_byte_data(data: &[u8], pos: usize, offset_map: &Value) -> Result<usize> {
1018 let total_bytes: usize = offset_map
1019 .as_object()
1020 .map(|m| {
1021 m.values()
1022 .filter_map(|v| v.as_u64())
1023 .map(|v| v as usize)
1024 .sum()
1025 })
1026 .unwrap_or(0);
1027
1028 if pos + total_bytes > data.len() {
1029 return Err(ChalkClientError::Api {
1030 status: 0,
1031 message: format!(
1032 "response truncated: byte data of {} bytes at position {} extends beyond data (total {})",
1033 total_bytes, pos, data.len()
1034 ),
1035 });
1036 }
1037
1038 Ok(pos + total_bytes)
1039}
1040
1041fn read_length_prefixed_json(data: &[u8], pos: usize) -> Result<(usize, Value)> {
1042 if pos + 8 > data.len() {
1043 return Err(ChalkClientError::Api {
1044 status: 0,
1045 message: format!(
1046 "response truncated: expected 8-byte length at position {}, but only {} bytes remain",
1047 pos,
1048 data.len() - pos
1049 ),
1050 });
1051 }
1052
1053 let len = u64::from_be_bytes(data[pos..pos + 8].try_into().unwrap()) as usize;
1054 let json_start = pos + 8;
1055
1056 if json_start + len > data.len() {
1057 return Err(ChalkClientError::Api {
1058 status: 0,
1059 message: format!(
1060 "response truncated: JSON payload of {} bytes at position {} extends beyond data (total {})",
1061 len, json_start, data.len()
1062 ),
1063 });
1064 }
1065
1066 let json_str = std::str::from_utf8(&data[json_start..json_start + len]).map_err(|e| {
1067 ChalkClientError::Api {
1068 status: 0,
1069 message: format!("invalid UTF-8 in response JSON: {}", e),
1070 }
1071 })?;
1072
1073 let value: Value = serde_json::from_str(json_str)?;
1074 Ok((json_start + len, value))
1075}
1076
1077fn build_byte_base_model(
1082 json_attrs: &Value,
1083 byte_attrs: &[(&str, &[u8])],
1084) -> Result<Vec<u8>> {
1085 let json_attrs_bytes = serde_json::to_vec(json_attrs)?;
1086 let empty_json = b"{}";
1087
1088 let byte_offset_map = {
1089 let mut map = serde_json::Map::new();
1090 for (key, data) in byte_attrs {
1091 map.insert((*key).to_string(), Value::Number((data.len() as u64).into()));
1092 }
1093 serde_json::to_vec(&Value::Object(map))?
1094 };
1095
1096 let total_byte_data: usize = byte_attrs.iter().map(|(_, d)| d.len()).sum();
1097
1098 let total_size = BYTEMODEL_MAGIC_STR.len()
1099 + 4 * 8
1100 + json_attrs_bytes.len()
1101 + empty_json.len()
1102 + byte_offset_map.len()
1103 + total_byte_data
1104 + empty_json.len();
1105 let mut buf = Vec::with_capacity(total_size);
1106
1107 buf.extend_from_slice(BYTEMODEL_MAGIC_STR);
1108
1109 buf.extend_from_slice(&(json_attrs_bytes.len() as u64).to_be_bytes());
1110 buf.extend_from_slice(&json_attrs_bytes);
1111
1112 buf.extend_from_slice(&(empty_json.len() as u64).to_be_bytes());
1113 buf.extend_from_slice(empty_json);
1114
1115 buf.extend_from_slice(&(byte_offset_map.len() as u64).to_be_bytes());
1116 buf.extend_from_slice(&byte_offset_map);
1117 for (_, data) in byte_attrs {
1118 buf.extend_from_slice(data);
1119 }
1120
1121 buf.extend_from_slice(&(empty_json.len() as u64).to_be_bytes());
1122 buf.extend_from_slice(empty_json);
1123
1124 Ok(buf)
1125}
1126
1127fn serialize_record_batch_to_feather(batch: &RecordBatch) -> Result<Vec<u8>> {
1132 let mut buf = Vec::new();
1133
1134 {
1135 let mut writer = FileWriter::try_new(&mut buf, &batch.schema())?;
1136 writer.write(batch)?;
1137 writer.finish()?;
1138 }
1139
1140 Ok(buf)
1141}
1142
1143#[cfg(test)]
1147mod tests {
1148 use super::*;
1149 use arrow::array::Int32Array;
1150 use arrow::datatypes::{DataType, Field, Schema};
1151 use std::sync::Arc;
1152
1153 #[test]
1154 fn test_serialize_record_batch_to_feather() {
1155 let schema = Arc::new(Schema::new(vec![Field::new(
1156 "user.id",
1157 DataType::Int32,
1158 false,
1159 )]));
1160 let batch =
1161 RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))]).unwrap();
1162
1163 let feather_bytes = serialize_record_batch_to_feather(&batch).unwrap();
1164 assert!(!feather_bytes.is_empty());
1165 assert_eq!(&feather_bytes[..6], b"ARROW1");
1166 }
1167
1168 #[test]
1169 fn test_build_feather_request_body() {
1170 let header = FeatherRequestHeader {
1171 outputs: vec!["user.id".into()],
1172 expression_outputs: vec![],
1173 now: None,
1174 staleness: None,
1175 context: None,
1176 include_meta: true,
1177 explain: false,
1178 correlation_id: None,
1179 query_name: None,
1180 query_name_version: None,
1181 deployment_id: None,
1182 branch_id: None,
1183 meta: None,
1184 store_plan_stages: Some(false),
1185 query_context: None,
1186 encoding_options: FeatureEncodingOptions {
1187 encode_structs_as_objects: None,
1188 },
1189 planner_options: None,
1190 value_metrics_tag_by_features: vec![],
1191 overlay_graph: None,
1192 };
1193
1194 let fake_feather = b"ARROW1fake_feather_data";
1195 let body = build_feather_request_body(&header, fake_feather).unwrap();
1196
1197 assert_eq!(&body[..5], b"chal1");
1198
1199 let header_len = u64::from_be_bytes(body[5..13].try_into().unwrap()) as usize;
1200 assert!(header_len > 0);
1201
1202 let header_json_str = std::str::from_utf8(&body[13..13 + header_len]).unwrap();
1203 let parsed: Value = serde_json::from_str(header_json_str).unwrap();
1204 assert_eq!(parsed["outputs"][0], "user.id");
1205 assert_eq!(parsed["include_meta"], true);
1206
1207 let body_len_start = 13 + header_len;
1208 let body_len =
1209 u64::from_be_bytes(body[body_len_start..body_len_start + 8].try_into().unwrap())
1210 as usize;
1211 assert_eq!(body_len, fake_feather.len());
1212
1213 let body_start = body_len_start + 8;
1214 assert_eq!(&body[body_start..body_start + body_len], fake_feather);
1215 }
1216
1217 #[tokio::test]
1218 async fn test_client_builder() {
1219 let mut server = mockito::Server::new_async().await;
1220
1221 let mock = server
1222 .mock("POST", "/v1/oauth/token")
1223 .with_status(200)
1224 .with_header("content-type", "application/json")
1225 .with_body(
1226 serde_json::json!({
1227 "access_token": "test-jwt",
1228 "expires_in": 3600,
1229 "primary_environment": "env-123",
1230 "engines": {"env-123": server.url()},
1231 "grpc_engines": {},
1232 "environment_id_to_name": {"env-123": "production"}
1233 })
1234 .to_string(),
1235 )
1236 .create_async()
1237 .await;
1238
1239 let client = ChalkClient::new()
1240 .client_id("test-id")
1241 .client_secret("test-secret")
1242 .api_server(&server.url())
1243 .environment("env-123")
1244 .build()
1245 .await
1246 .unwrap();
1247
1248 assert_eq!(client.environment_id(), "env-123");
1249 assert_eq!(client.query_server(), &server.url());
1250 mock.assert_async().await;
1251 }
1252
1253 #[tokio::test]
1254 async fn test_query() {
1255 let mut server = mockito::Server::new_async().await;
1256
1257 server
1258 .mock("POST", "/v1/oauth/token")
1259 .with_status(200)
1260 .with_header("content-type", "application/json")
1261 .with_body(
1262 serde_json::json!({
1263 "access_token": "test-jwt",
1264 "expires_in": 3600,
1265 "primary_environment": "env-1",
1266 "engines": {"env-1": server.url()},
1267 "grpc_engines": {}
1268 })
1269 .to_string(),
1270 )
1271 .create_async()
1272 .await;
1273
1274 let query_mock = server
1275 .mock("POST", "/v1/query/online")
1276 .match_header("Authorization", "Bearer test-jwt")
1277 .match_header("X-Chalk-Env-Id", "env-1")
1278 .with_status(200)
1279 .with_header("content-type", "application/json")
1280 .with_body(
1281 serde_json::json!({
1282 "data": [
1283 {"field": "user.age", "value": 25},
1284 {"field": "user.name", "value": "Alice"}
1285 ],
1286 "errors": []
1287 })
1288 .to_string(),
1289 )
1290 .create_async()
1291 .await;
1292
1293 let client = ChalkClient::new()
1294 .client_id("test-id")
1295 .client_secret("test-secret")
1296 .api_server(&server.url())
1297 .environment("env-1")
1298 .build()
1299 .await
1300 .unwrap();
1301
1302 let inputs = HashMap::from([("user.id".into(), serde_json::json!(42))]);
1303 let outputs = vec!["user.age".into(), "user.name".into()];
1304
1305 let response = client
1306 .query(inputs, outputs, QueryOptions::default())
1307 .await
1308 .unwrap();
1309
1310 assert_eq!(response.data.len(), 2);
1311 assert_eq!(response.data[0].field, "user.age");
1312 assert_eq!(response.data[0].value, serde_json::json!(25));
1313 assert_eq!(response.data[1].field, "user.name");
1314 assert_eq!(response.data[1].value, serde_json::json!("Alice"));
1315
1316 query_mock.assert_async().await;
1317 }
1318
1319 #[tokio::test]
1320 async fn test_query_api_error() {
1321 let mut server = mockito::Server::new_async().await;
1322
1323 server
1324 .mock("POST", "/v1/oauth/token")
1325 .with_status(200)
1326 .with_header("content-type", "application/json")
1327 .with_body(
1328 serde_json::json!({
1329 "access_token": "jwt",
1330 "expires_in": 3600,
1331 "primary_environment": "env-1",
1332 "engines": {"env-1": server.url()},
1333 "grpc_engines": {}
1334 })
1335 .to_string(),
1336 )
1337 .create_async()
1338 .await;
1339
1340 server
1341 .mock("POST", "/v1/query/online")
1342 .with_status(500)
1343 .with_body("internal server error")
1344 .create_async()
1345 .await;
1346
1347 let client = ChalkClient::new()
1348 .client_id("id")
1349 .client_secret("secret")
1350 .api_server(&server.url())
1351 .environment("env-1")
1352 .build()
1353 .await
1354 .unwrap();
1355
1356 let result = client
1357 .query(HashMap::new(), vec![], QueryOptions::default())
1358 .await;
1359
1360 assert!(result.is_err());
1361 match result.unwrap_err() {
1362 ChalkClientError::Api { status, message } => {
1363 assert_eq!(status, 500);
1364 assert!(message.contains("internal server error"));
1365 }
1366 other => panic!("expected Api error, got: {:?}", other),
1367 }
1368 }
1369
1370 #[tokio::test]
1371 async fn test_offline_query() {
1372 let mut server = mockito::Server::new_async().await;
1373
1374 server
1375 .mock("POST", "/v1/oauth/token")
1376 .with_status(200)
1377 .with_header("content-type", "application/json")
1378 .with_body(
1379 serde_json::json!({
1380 "access_token": "jwt",
1381 "expires_in": 3600,
1382 "primary_environment": "env-1",
1383 "engines": {"env-1": server.url()},
1384 "grpc_engines": {}
1385 })
1386 .to_string(),
1387 )
1388 .create_async()
1389 .await;
1390
1391 let offline_mock = server
1392 .mock("POST", "/v4/offline_query")
1393 .match_header("Authorization", "Bearer jwt")
1394 .with_status(200)
1395 .with_header("content-type", "application/json")
1396 .with_body(
1397 serde_json::json!({
1398 "is_finished": false,
1399 "dataset_id": "ds-123",
1400 "revisions": [{
1401 "revision_id": "rev-1",
1402 "status": "pending"
1403 }],
1404 "errors": []
1405 })
1406 .to_string(),
1407 )
1408 .create_async()
1409 .await;
1410
1411 let client = ChalkClient::new()
1412 .client_id("id")
1413 .client_secret("secret")
1414 .api_server(&server.url())
1415 .environment("env-1")
1416 .build()
1417 .await
1418 .unwrap();
1419
1420 let request = OfflineQueryRequest {
1421 input: None,
1422 output: vec!["user.ltv".into()],
1423 destination_format: Some("PARQUET".into()),
1424 job_id: None,
1425 max_samples: None,
1426 max_cache_age_secs: None,
1427 observed_at_lower_bound: None,
1428 observed_at_upper_bound: None,
1429 dataset_name: None,
1430 branch: None,
1431 recompute_features: None,
1432 tags: None,
1433 required_resolver_tags: None,
1434 correlation_id: None,
1435 store_online: None,
1436 store_offline: None,
1437 required_output: None,
1438 run_asynchronously: None,
1439 num_shards: None,
1440 num_workers: None,
1441 resources: None,
1442 completion_deadline: None,
1443 max_retries: None,
1444 store_plan_stages: None,
1445 explain: None,
1446 planner_options: None,
1447 query_context: None,
1448 use_multiple_computers: None,
1449 spine_sql_query: None,
1450 query_name: None,
1451 query_name_version: None,
1452 };
1453
1454 let response = client.offline_query_raw(request).await.unwrap();
1455 assert!(!response.is_finished);
1456 assert_eq!(response.dataset_id.as_deref(), Some("ds-123"));
1457 assert_eq!(response.revisions.len(), 1);
1458
1459 offline_mock.assert_async().await;
1460 }
1461
1462 #[tokio::test]
1463 async fn test_offline_query_with_builder() {
1464 let mut server = mockito::Server::new_async().await;
1465
1466 server
1467 .mock("POST", "/v1/oauth/token")
1468 .with_status(200)
1469 .with_header("content-type", "application/json")
1470 .with_body(
1471 serde_json::json!({
1472 "access_token": "jwt",
1473 "expires_in": 3600,
1474 "primary_environment": "env-1",
1475 "engines": {"env-1": server.url()},
1476 "grpc_engines": {}
1477 })
1478 .to_string(),
1479 )
1480 .create_async()
1481 .await;
1482
1483 let offline_mock = server
1484 .mock("POST", "/v4/offline_query")
1485 .with_status(200)
1486 .with_header("content-type", "application/json")
1487 .with_body(
1488 serde_json::json!({
1489 "is_finished": false,
1490 "dataset_id": "ds-456",
1491 "revisions": [{
1492 "revision_id": "rev-2",
1493 "status": "pending"
1494 }],
1495 "errors": []
1496 })
1497 .to_string(),
1498 )
1499 .create_async()
1500 .await;
1501
1502 let client = ChalkClient::new()
1503 .client_id("id")
1504 .client_secret("secret")
1505 .api_server(&server.url())
1506 .environment("env-1")
1507 .build()
1508 .await
1509 .unwrap();
1510
1511 use crate::offline::OfflineQueryParams;
1512
1513 let response = client
1514 .offline_query(
1515 OfflineQueryParams::new()
1516 .with_input("user.id", vec![serde_json::json!(1), serde_json::json!(2)])
1517 .with_output("user.email")
1518 .with_output("user.ltv")
1519 .with_num_shards(4),
1520 )
1521 .await
1522 .unwrap();
1523
1524 assert!(!response.is_finished);
1525 assert_eq!(response.dataset_id.as_deref(), Some("ds-456"));
1526 offline_mock.assert_async().await;
1527 }
1528
1529 #[tokio::test]
1530 async fn test_wait_for_offline_query_success() {
1531 let mut server = mockito::Server::new_async().await;
1532
1533 server
1534 .mock("POST", "/v1/oauth/token")
1535 .with_status(200)
1536 .with_header("content-type", "application/json")
1537 .with_body(
1538 serde_json::json!({
1539 "access_token": "jwt",
1540 "expires_in": 3600,
1541 "primary_environment": "env-1",
1542 "engines": {"env-1": server.url()},
1543 "grpc_engines": {}
1544 })
1545 .to_string(),
1546 )
1547 .create_async()
1548 .await;
1549
1550 server
1551 .mock("GET", "/v4/offline_query/rev-1/status")
1552 .with_status(200)
1553 .with_header("content-type", "application/json")
1554 .with_body(
1555 serde_json::json!({
1556 "report": {
1557 "status": "RUNNING"
1558 }
1559 })
1560 .to_string(),
1561 )
1562 .create_async()
1563 .await;
1564
1565 server
1566 .mock("GET", "/v4/offline_query/rev-1/status")
1567 .with_status(200)
1568 .with_header("content-type", "application/json")
1569 .with_body(
1570 serde_json::json!({
1571 "report": {
1572 "status": "COMPLETED"
1573 }
1574 })
1575 .to_string(),
1576 )
1577 .create_async()
1578 .await;
1579
1580 let client = ChalkClient::new()
1581 .client_id("id")
1582 .client_secret("secret")
1583 .api_server(&server.url())
1584 .environment("env-1")
1585 .build()
1586 .await
1587 .unwrap();
1588
1589 let response = OfflineQueryResponse {
1590 is_finished: false,
1591 version: None,
1592 dataset_id: Some("ds-123".into()),
1593 dataset_name: None,
1594 environment_id: None,
1595 revisions: vec![crate::types::DatasetRevision {
1596 revision_id: Some("rev-1".into()),
1597 creator_id: None,
1598 environment_id: None,
1599 outputs: vec![],
1600 status: Some("pending".into()),
1601 num_partitions: None,
1602 output_uris: None,
1603 created_at: None,
1604 started_at: None,
1605 terminated_at: None,
1606 dashboard_url: None,
1607 dataset_name: None,
1608 dataset_id: None,
1609 branch: None,
1610 }],
1611 errors: vec![],
1612 };
1613
1614 let result = client
1615 .wait_for_offline_query(&response, Some(Duration::from_secs(5)))
1616 .await;
1617 assert!(result.is_ok());
1618 }
1619
1620 #[tokio::test]
1621 async fn test_wait_for_offline_query_failure() {
1622 let mut server = mockito::Server::new_async().await;
1623
1624 server
1625 .mock("POST", "/v1/oauth/token")
1626 .with_status(200)
1627 .with_header("content-type", "application/json")
1628 .with_body(
1629 serde_json::json!({
1630 "access_token": "jwt",
1631 "expires_in": 3600,
1632 "primary_environment": "env-1",
1633 "engines": {"env-1": server.url()},
1634 "grpc_engines": {}
1635 })
1636 .to_string(),
1637 )
1638 .create_async()
1639 .await;
1640
1641 server
1642 .mock("GET", "/v4/offline_query/rev-1/status")
1643 .with_status(200)
1644 .with_header("content-type", "application/json")
1645 .with_body(
1646 serde_json::json!({
1647 "report": {
1648 "status": "FAILED",
1649 "all_errors": [{
1650 "code": "INTERNAL_ERROR",
1651 "category": "REQUEST",
1652 "message": "job failed due to OOM"
1653 }]
1654 }
1655 })
1656 .to_string(),
1657 )
1658 .create_async()
1659 .await;
1660
1661 let client = ChalkClient::new()
1662 .client_id("id")
1663 .client_secret("secret")
1664 .api_server(&server.url())
1665 .environment("env-1")
1666 .build()
1667 .await
1668 .unwrap();
1669
1670 let response = OfflineQueryResponse {
1671 is_finished: false,
1672 version: None,
1673 dataset_id: None,
1674 dataset_name: None,
1675 environment_id: None,
1676 revisions: vec![crate::types::DatasetRevision {
1677 revision_id: Some("rev-1".into()),
1678 creator_id: None,
1679 environment_id: None,
1680 outputs: vec![],
1681 status: None,
1682 num_partitions: None,
1683 output_uris: None,
1684 created_at: None,
1685 started_at: None,
1686 terminated_at: None,
1687 dashboard_url: None,
1688 dataset_name: None,
1689 dataset_id: None,
1690 branch: None,
1691 }],
1692 errors: vec![],
1693 };
1694
1695 let result = client
1696 .wait_for_offline_query(&response, Some(Duration::from_secs(5)))
1697 .await;
1698 assert!(result.is_err());
1699 let err = result.unwrap_err().to_string();
1700 assert!(err.contains("OOM"));
1701 }
1702
1703 #[tokio::test]
1704 async fn test_get_offline_query_download_urls() {
1705 let mut server = mockito::Server::new_async().await;
1706
1707 server
1708 .mock("POST", "/v1/oauth/token")
1709 .with_status(200)
1710 .with_header("content-type", "application/json")
1711 .with_body(
1712 serde_json::json!({
1713 "access_token": "jwt",
1714 "expires_in": 3600,
1715 "primary_environment": "env-1",
1716 "engines": {"env-1": server.url()},
1717 "grpc_engines": {}
1718 })
1719 .to_string(),
1720 )
1721 .create_async()
1722 .await;
1723
1724 server
1725 .mock("GET", "/v2/offline_query/rev-1")
1726 .with_status(200)
1727 .with_header("content-type", "application/json")
1728 .with_body(
1729 serde_json::json!({
1730 "is_finished": false,
1731 "urls": [],
1732 "errors": []
1733 })
1734 .to_string(),
1735 )
1736 .create_async()
1737 .await;
1738
1739 server
1740 .mock("GET", "/v2/offline_query/rev-1")
1741 .with_status(200)
1742 .with_header("content-type", "application/json")
1743 .with_body(
1744 serde_json::json!({
1745 "is_finished": true,
1746 "urls": [
1747 "https://storage.example.com/results/part-0.parquet",
1748 "https://storage.example.com/results/part-1.parquet"
1749 ],
1750 "errors": []
1751 })
1752 .to_string(),
1753 )
1754 .create_async()
1755 .await;
1756
1757 let client = ChalkClient::new()
1758 .client_id("id")
1759 .client_secret("secret")
1760 .api_server(&server.url())
1761 .environment("env-1")
1762 .build()
1763 .await
1764 .unwrap();
1765
1766 let response = OfflineQueryResponse {
1767 is_finished: false,
1768 version: None,
1769 dataset_id: None,
1770 dataset_name: None,
1771 environment_id: None,
1772 revisions: vec![crate::types::DatasetRevision {
1773 revision_id: Some("rev-1".into()),
1774 creator_id: None,
1775 environment_id: None,
1776 outputs: vec![],
1777 status: None,
1778 num_partitions: None,
1779 output_uris: None,
1780 created_at: None,
1781 started_at: None,
1782 terminated_at: None,
1783 dashboard_url: None,
1784 dataset_name: None,
1785 dataset_id: None,
1786 branch: None,
1787 }],
1788 errors: vec![],
1789 };
1790
1791 let urls = client
1792 .get_offline_query_download_urls(&response, Some(Duration::from_secs(5)))
1793 .await
1794 .unwrap();
1795
1796 assert_eq!(urls.len(), 2);
1797 assert!(urls[0].contains("part-0.parquet"));
1798 assert!(urls[1].contains("part-1.parquet"));
1799 }
1800
1801 #[tokio::test]
1802 async fn test_wait_for_offline_query_timeout() {
1803 let mut server = mockito::Server::new_async().await;
1804
1805 server
1806 .mock("POST", "/v1/oauth/token")
1807 .with_status(200)
1808 .with_header("content-type", "application/json")
1809 .with_body(
1810 serde_json::json!({
1811 "access_token": "jwt",
1812 "expires_in": 3600,
1813 "primary_environment": "env-1",
1814 "engines": {"env-1": server.url()},
1815 "grpc_engines": {}
1816 })
1817 .to_string(),
1818 )
1819 .create_async()
1820 .await;
1821
1822 server
1823 .mock("GET", "/v4/offline_query/rev-1/status")
1824 .with_status(200)
1825 .with_header("content-type", "application/json")
1826 .with_body(
1827 serde_json::json!({
1828 "report": {
1829 "status": "RUNNING"
1830 }
1831 })
1832 .to_string(),
1833 )
1834 .expect_at_least(1)
1835 .create_async()
1836 .await;
1837
1838 let client = ChalkClient::new()
1839 .client_id("id")
1840 .client_secret("secret")
1841 .api_server(&server.url())
1842 .environment("env-1")
1843 .build()
1844 .await
1845 .unwrap();
1846
1847 let response = OfflineQueryResponse {
1848 is_finished: false,
1849 version: None,
1850 dataset_id: None,
1851 dataset_name: None,
1852 environment_id: None,
1853 revisions: vec![crate::types::DatasetRevision {
1854 revision_id: Some("rev-1".into()),
1855 creator_id: None,
1856 environment_id: None,
1857 outputs: vec![],
1858 status: None,
1859 num_partitions: None,
1860 output_uris: None,
1861 created_at: None,
1862 started_at: None,
1863 terminated_at: None,
1864 dashboard_url: None,
1865 dataset_name: None,
1866 dataset_id: None,
1867 branch: None,
1868 }],
1869 errors: vec![],
1870 };
1871
1872 let result = client
1873 .wait_for_offline_query(&response, Some(Duration::from_millis(500)))
1874 .await;
1875 assert!(result.is_err());
1876 let err = result.unwrap_err().to_string();
1877 assert!(err.contains("timed out"));
1878 }
1879
1880 #[test]
1881 fn test_build_byte_base_model() {
1882 let json_attrs = serde_json::json!({
1883 "features": ["user.id", "user.age"],
1884 "table_compression": "uncompressed",
1885 });
1886 let fake_arrow = b"ARROW1fake_data_here";
1887
1888 let body = build_byte_base_model(&json_attrs, &[("table_bytes", fake_arrow.as_slice())])
1889 .unwrap();
1890
1891 let mut pos = 0;
1892
1893 assert_eq!(
1894 &body[pos..pos + BYTEMODEL_MAGIC_STR.len()],
1895 BYTEMODEL_MAGIC_STR
1896 );
1897 pos += BYTEMODEL_MAGIC_STR.len();
1898
1899 let json_attrs_len =
1900 u64::from_be_bytes(body[pos..pos + 8].try_into().unwrap()) as usize;
1901 pos += 8;
1902 let json_attrs_parsed: Value =
1903 serde_json::from_slice(&body[pos..pos + json_attrs_len]).unwrap();
1904 assert_eq!(json_attrs_parsed["features"][0], "user.id");
1905 assert_eq!(json_attrs_parsed["table_compression"], "uncompressed");
1906 pos += json_attrs_len;
1907
1908 let pydantic_len =
1909 u64::from_be_bytes(body[pos..pos + 8].try_into().unwrap()) as usize;
1910 pos += 8;
1911 let pydantic: Value =
1912 serde_json::from_slice(&body[pos..pos + pydantic_len]).unwrap();
1913 assert_eq!(pydantic, serde_json::json!({}));
1914 pos += pydantic_len;
1915
1916 let byte_map_len =
1917 u64::from_be_bytes(body[pos..pos + 8].try_into().unwrap()) as usize;
1918 pos += 8;
1919 let byte_map: Value =
1920 serde_json::from_slice(&body[pos..pos + byte_map_len]).unwrap();
1921 assert_eq!(byte_map["table_bytes"], fake_arrow.len() as u64);
1922 pos += byte_map_len;
1923
1924 assert_eq!(&body[pos..pos + fake_arrow.len()], fake_arrow);
1925 pos += fake_arrow.len();
1926
1927 let ser_len =
1928 u64::from_be_bytes(body[pos..pos + 8].try_into().unwrap()) as usize;
1929 pos += 8;
1930 let ser: Value = serde_json::from_slice(&body[pos..pos + ser_len]).unwrap();
1931 assert_eq!(ser, serde_json::json!({}));
1932 pos += ser_len;
1933
1934 assert_eq!(pos, body.len());
1935 }
1936
1937 #[tokio::test]
1938 async fn test_upload_features() {
1939 let mut server = mockito::Server::new_async().await;
1940
1941 server
1942 .mock("POST", "/v1/oauth/token")
1943 .with_status(200)
1944 .with_header("content-type", "application/json")
1945 .with_body(
1946 serde_json::json!({
1947 "access_token": "jwt",
1948 "expires_in": 3600,
1949 "primary_environment": "env-1",
1950 "engines": {"env-1": server.url()},
1951 "grpc_engines": {}
1952 })
1953 .to_string(),
1954 )
1955 .create_async()
1956 .await;
1957
1958 let upload_mock = server
1959 .mock("POST", "/v1/upload_features/multi")
1960 .match_header("Authorization", "Bearer jwt")
1961 .match_header("Content-Type", "application/octet-stream")
1962 .with_status(200)
1963 .with_header("content-type", "application/json")
1964 .with_body(
1965 serde_json::json!({
1966 "operation_id": "op-abc-123",
1967 "errors": []
1968 })
1969 .to_string(),
1970 )
1971 .create_async()
1972 .await;
1973
1974 let client = ChalkClient::new()
1975 .client_id("id")
1976 .client_secret("secret")
1977 .api_server(&server.url())
1978 .environment("env-1")
1979 .build()
1980 .await
1981 .unwrap();
1982
1983 let schema = Arc::new(Schema::new(vec![
1984 Field::new("user.id", DataType::Int32, false),
1985 Field::new("user.age", DataType::Int32, true),
1986 ]));
1987 let batch = RecordBatch::try_new(
1988 schema,
1989 vec![
1990 Arc::new(Int32Array::from(vec![1, 2, 3])),
1991 Arc::new(Int32Array::from(vec![25, 30, 22])),
1992 ],
1993 )
1994 .unwrap();
1995
1996 let result = client.upload_features(&batch).await.unwrap();
1997 assert_eq!(result.operation_id.as_deref(), Some("op-abc-123"));
1998 assert!(result.errors.is_empty());
1999
2000 upload_mock.assert_async().await;
2001 }
2002
2003 #[tokio::test]
2004 async fn test_upload_features_map() {
2005 let mut server = mockito::Server::new_async().await;
2006
2007 server
2008 .mock("POST", "/v1/oauth/token")
2009 .with_status(200)
2010 .with_header("content-type", "application/json")
2011 .with_body(
2012 serde_json::json!({
2013 "access_token": "jwt",
2014 "expires_in": 3600,
2015 "primary_environment": "env-1",
2016 "engines": {"env-1": server.url()},
2017 "grpc_engines": {}
2018 })
2019 .to_string(),
2020 )
2021 .create_async()
2022 .await;
2023
2024 let upload_mock = server
2025 .mock("POST", "/v1/upload_features/multi")
2026 .with_status(200)
2027 .with_header("content-type", "application/json")
2028 .with_body(
2029 serde_json::json!({
2030 "operation_id": "op-map-456",
2031 "errors": []
2032 })
2033 .to_string(),
2034 )
2035 .create_async()
2036 .await;
2037
2038 let client = ChalkClient::new()
2039 .client_id("id")
2040 .client_secret("secret")
2041 .api_server(&server.url())
2042 .environment("env-1")
2043 .build()
2044 .await
2045 .unwrap();
2046
2047 let inputs = HashMap::from([
2048 (
2049 "user.id".to_string(),
2050 vec![serde_json::json!(1), serde_json::json!(2)],
2051 ),
2052 (
2053 "user.name".to_string(),
2054 vec![serde_json::json!("Alice"), serde_json::json!("Bob")],
2055 ),
2056 ]);
2057
2058 let result = client.upload_features_map(inputs).await.unwrap();
2059 assert_eq!(result.operation_id.as_deref(), Some("op-map-456"));
2060
2061 upload_mock.assert_async().await;
2062 }
2063
2064 #[tokio::test]
2065 async fn test_upload_features_map_empty_inputs() {
2066 let mut server = mockito::Server::new_async().await;
2067
2068 server
2069 .mock("POST", "/v1/oauth/token")
2070 .with_status(200)
2071 .with_header("content-type", "application/json")
2072 .with_body(
2073 serde_json::json!({
2074 "access_token": "jwt",
2075 "expires_in": 3600,
2076 "primary_environment": "env-1",
2077 "engines": {"env-1": server.url()},
2078 "grpc_engines": {}
2079 })
2080 .to_string(),
2081 )
2082 .create_async()
2083 .await;
2084
2085 let client = ChalkClient::new()
2086 .client_id("id")
2087 .client_secret("secret")
2088 .api_server(&server.url())
2089 .environment("env-1")
2090 .build()
2091 .await
2092 .unwrap();
2093
2094 let result = client.upload_features_map(HashMap::new()).await;
2095 assert!(result.is_err());
2096 let err = result.unwrap_err().to_string();
2097 assert!(err.contains("at least one feature"));
2098 }
2099}