1use std::collections::HashMap;
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Deserializer, Serialize};
10use serde_json::Value;
11
12fn deserialize_null_as_empty_vec<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
14where
15 D: Deserializer<'de>,
16 T: Deserialize<'de>,
17{
18 let opt = Option::<Vec<T>>::deserialize(deserializer)?;
19 Ok(opt.unwrap_or_default())
20}
21
22fn deserialize_status_flexible<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
24where
25 D: Deserializer<'de>,
26{
27 let v = Option::<Value>::deserialize(deserializer)?;
28 match v {
29 None | Some(Value::Null) => Ok(None),
30 Some(Value::String(s)) => Ok(Some(s)),
31 Some(Value::Number(n)) => {
32 let label = match n.as_i64() {
34 Some(1) => "pending_submission",
35 Some(2) => "submitted",
36 Some(3) => "running",
37 Some(4) => "error",
38 Some(5) => "expired",
39 Some(6) => "cancelled",
40 Some(7) => "successful",
41 _ => "unknown",
42 };
43 Ok(Some(label.to_string()))
44 }
45 Some(other) => Ok(Some(other.to_string())),
46 }
47}
48
49#[derive(Debug, Clone, Serialize)]
55pub struct OnlineQueryRequest {
56 pub inputs: HashMap<String, Value>,
58
59 pub outputs: Vec<String>,
61
62 #[serde(skip_serializing_if = "Option::is_none")]
64 pub context: Option<OnlineQueryContext>,
65
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub staleness: Option<HashMap<String, String>>,
69
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub include_meta: Option<bool>,
73
74 #[serde(skip_serializing_if = "Option::is_none")]
76 pub query_name: Option<String>,
77
78 #[serde(skip_serializing_if = "Option::is_none")]
80 pub correlation_id: Option<String>,
81
82 #[serde(skip_serializing_if = "Option::is_none")]
84 pub query_context: Option<HashMap<String, Value>>,
85
86 #[serde(skip_serializing_if = "Option::is_none")]
88 pub meta: Option<HashMap<String, String>>,
89
90 #[serde(skip_serializing_if = "Option::is_none")]
92 pub query_name_version: Option<String>,
93
94 #[serde(skip_serializing_if = "Option::is_none")]
96 pub now: Option<String>,
97
98 #[serde(skip_serializing_if = "Option::is_none")]
100 pub explain: Option<bool>,
101
102 #[serde(skip_serializing_if = "Option::is_none")]
104 pub store_plan_stages: Option<bool>,
105
106 #[serde(skip_serializing_if = "Option::is_none")]
108 pub encoding_options: Option<FeatureEncodingOptions>,
109
110 #[serde(skip_serializing_if = "Option::is_none")]
112 pub branch_id: Option<String>,
113}
114
/// Optional tag context attached to an online query.
///
/// Each field is omitted from the serialized JSON when it is `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OnlineQueryContext {
    /// Tags attached to the query (server-side semantics — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tags: Option<Vec<String>>,

    /// Resolver tags the query requires; presumably restricts which resolvers
    /// may run (TODO confirm with the API).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub required_resolver_tags: Option<Vec<String>>,
}
124
/// Options controlling how feature values are encoded in query responses.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FeatureEncodingOptions {
    /// When set, asks the server to encode struct-typed features as JSON
    /// objects; the server's behaviour when unset is not visible here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encode_structs_as_objects: Option<bool>,
}
131
132#[derive(Debug, Clone, Deserialize)]
138pub struct OnlineQueryResponse {
139 pub data: Vec<FeatureResult>,
140
141 #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")]
142 pub errors: Vec<ChalkError>,
143
144 #[serde(default)]
145 pub meta: Option<QueryMeta>,
146}
147
148#[derive(Debug, Clone, Deserialize)]
150pub struct FeatureResult {
151 pub field: String,
152 pub value: Value,
153
154 #[serde(default)]
155 pub pkey: Option<Value>,
156
157 #[serde(default)]
158 pub ts: Option<String>,
159
160 #[serde(default)]
161 pub meta: Option<FeatureMeta>,
162
163 #[serde(default)]
164 pub error: Option<ChalkError>,
165}
166
/// Metadata attached to a single feature result. Every field is optional.
#[derive(Debug, Clone, Deserialize)]
pub struct FeatureMeta {
    /// Fully-qualified name of the resolver that produced the value.
    #[serde(default)]
    pub chosen_resolver_fqn: Option<String>,

    /// Whether the value was served from cache.
    #[serde(default)]
    pub cache_hit: Option<bool>,

    /// Server-reported primitive type name of the value.
    #[serde(default)]
    pub primitive_type: Option<String>,

    /// Feature version, if versioned.
    #[serde(default)]
    pub version: Option<i64>,
}
182
/// Query-level metadata returned with an online query response.
/// Every field is optional.
#[derive(Debug, Clone, Deserialize)]
pub struct QueryMeta {
    /// Execution time in seconds as reported by the server.
    #[serde(default)]
    pub execution_duration_s: Option<f64>,

    #[serde(default)]
    pub deployment_id: Option<String>,

    #[serde(default)]
    pub environment_id: Option<String>,

    #[serde(default)]
    pub environment_name: Option<String>,

    /// Server-assigned identifier of this query.
    #[serde(default)]
    pub query_id: Option<String>,

    /// Query timestamp, parsed into UTC via chrono's serde support
    /// (presumably an RFC 3339 string on the wire — confirm).
    #[serde(default)]
    pub query_timestamp: Option<DateTime<Utc>>,

    #[serde(default)]
    pub query_hash: Option<String>,
}
207
208#[derive(Debug, Clone, Serialize)]
214pub struct OfflineQueryRequest {
215 #[serde(skip_serializing_if = "Option::is_none")]
216 pub input: Option<OfflineQueryInputType>,
217
218 pub output: Vec<String>,
219
220 #[serde(skip_serializing_if = "Option::is_none")]
221 pub destination_format: Option<String>,
222
223 #[serde(skip_serializing_if = "Option::is_none")]
224 pub job_id: Option<String>,
225
226 #[serde(skip_serializing_if = "Option::is_none")]
227 pub max_samples: Option<i64>,
228
229 #[serde(skip_serializing_if = "Option::is_none")]
230 pub max_cache_age_secs: Option<i64>,
231
232 #[serde(skip_serializing_if = "Option::is_none")]
233 pub observed_at_lower_bound: Option<String>,
234
235 #[serde(skip_serializing_if = "Option::is_none")]
236 pub observed_at_upper_bound: Option<String>,
237
238 #[serde(skip_serializing_if = "Option::is_none")]
239 pub dataset_name: Option<String>,
240
241 #[serde(skip_serializing_if = "Option::is_none")]
242 pub branch: Option<String>,
243
244 #[serde(skip_serializing_if = "Option::is_none")]
245 pub recompute_features: Option<Value>,
246
247 #[serde(skip_serializing_if = "Option::is_none")]
248 pub tags: Option<Vec<String>>,
249
250 #[serde(skip_serializing_if = "Option::is_none")]
251 pub required_resolver_tags: Option<Vec<String>>,
252
253 #[serde(skip_serializing_if = "Option::is_none")]
254 pub correlation_id: Option<String>,
255
256 #[serde(skip_serializing_if = "Option::is_none")]
257 pub store_online: Option<bool>,
258
259 #[serde(skip_serializing_if = "Option::is_none")]
260 pub store_offline: Option<bool>,
261
262 #[serde(skip_serializing_if = "Option::is_none")]
263 pub required_output: Option<Vec<String>>,
264
265 #[serde(skip_serializing_if = "Option::is_none")]
266 pub run_asynchronously: Option<bool>,
267
268 #[serde(skip_serializing_if = "Option::is_none")]
269 pub num_shards: Option<i64>,
270
271 #[serde(skip_serializing_if = "Option::is_none")]
272 pub num_workers: Option<i64>,
273
274 #[serde(skip_serializing_if = "Option::is_none")]
275 pub resources: Option<ResourceRequests>,
276
277 #[serde(skip_serializing_if = "Option::is_none")]
278 pub completion_deadline: Option<String>,
279
280 #[serde(skip_serializing_if = "Option::is_none")]
281 pub max_retries: Option<i64>,
282
283 #[serde(skip_serializing_if = "Option::is_none")]
284 pub store_plan_stages: Option<bool>,
285
286 #[serde(skip_serializing_if = "Option::is_none")]
287 pub explain: Option<bool>,
288
289 #[serde(skip_serializing_if = "Option::is_none")]
290 pub planner_options: Option<HashMap<String, Value>>,
291
292 #[serde(skip_serializing_if = "Option::is_none")]
293 pub query_context: Option<HashMap<String, Value>>,
294
295 #[serde(skip_serializing_if = "Option::is_none")]
296 pub use_multiple_computers: Option<bool>,
297
298 #[serde(skip_serializing_if = "Option::is_none")]
299 pub spine_sql_query: Option<String>,
300
301 #[serde(skip_serializing_if = "Option::is_none")]
302 pub query_name: Option<String>,
303
304 #[serde(skip_serializing_if = "Option::is_none")]
305 pub query_name_version: Option<String>,
306}
307
/// Inline tabular input for an offline query: column names plus values.
///
/// NOTE(review): the tests in this module put one inner `Vec` per column,
/// so `values[i]` holds the data for `columns[i]` (column-major). Confirm
/// against the API before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OfflineQueryInput {
    pub columns: Vec<String>,
    pub values: Vec<Vec<Value>>,
}
314
/// The three accepted shapes of offline-query input.
///
/// `#[serde(untagged)]`: each variant serializes as its inner struct's JSON
/// object, with no wrapping tag. On deserialization serde tries the variants
/// in declaration order and picks the first whose required fields match.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum OfflineQueryInputType {
    /// Inline `columns` + `values`.
    Inline(OfflineQueryInput),
    /// A `parquet_uri`, optionally restricted to a row range.
    Uri(OfflineQueryInputUri),
    /// An `input_sql` query.
    Sql(OfflineQueryInputSql),
}
323
/// Offline-query input read from a parquet file URI (e.g. `s3://...`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OfflineQueryInputUri {
    pub parquet_uri: String,
    /// First row to read; whether the bound is inclusive is not visible
    /// here (TODO confirm). Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub start_row: Option<i64>,
    /// Last row bound; inclusive/exclusive semantics not visible here
    /// (TODO confirm). Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_row: Option<i64>,
}
333
/// Offline-query input produced by a SQL query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OfflineQueryInputSql {
    pub input_sql: String,
}
339
/// Compute resources requested for an offline query job.
///
/// Values are passed through as strings; presumably Kubernetes-style
/// quantities such as `"2"` or `"4Gi"` (TODO confirm). `None` fields are
/// omitted from the JSON.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResourceRequests {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cpu: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub memory: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ephemeral_storage: Option<String>,
}
350
/// Response to submitting an offline query.
///
/// Every field is tolerant of absence. `is_finished` defaults to `false`, and
/// the list fields accept a missing key or `null` (both become an empty `Vec`).
#[derive(Debug, Clone, Deserialize)]
pub struct OfflineQueryResponse {
    #[serde(default)]
    pub is_finished: bool,

    #[serde(default)]
    pub version: Option<i64>,

    #[serde(default)]
    pub dataset_id: Option<String>,

    #[serde(default)]
    pub dataset_name: Option<String>,

    #[serde(default)]
    pub environment_id: Option<String>,

    /// Revisions of the produced dataset.
    #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")]
    pub revisions: Vec<DatasetRevision>,

    #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")]
    pub errors: Vec<ChalkError>,
}
379
380#[derive(Debug, Clone, Deserialize)]
382pub struct DatasetRevision {
383 #[serde(default)]
384 pub revision_id: Option<String>,
385
386 #[serde(default)]
387 pub creator_id: Option<String>,
388
389 #[serde(default)]
390 pub environment_id: Option<String>,
391
392 #[serde(default)]
393 pub outputs: Vec<String>,
394
395 #[serde(default, deserialize_with = "deserialize_status_flexible")]
396 pub status: Option<String>,
397
398 #[serde(default)]
399 pub num_partitions: Option<i64>,
400
401 #[serde(default)]
402 pub output_uris: Option<String>,
403
404 #[serde(default)]
405 pub created_at: Option<DateTime<Utc>>,
406
407 #[serde(default)]
408 pub started_at: Option<DateTime<Utc>>,
409
410 #[serde(default)]
411 pub terminated_at: Option<DateTime<Utc>>,
412
413 #[serde(default)]
414 pub dashboard_url: Option<String>,
415
416 #[serde(default)]
417 pub dataset_name: Option<String>,
418
419 #[serde(default)]
420 pub dataset_id: Option<String>,
421
422 #[serde(default)]
423 pub branch: Option<String>,
424}
425
/// Response of the offline-query status endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct GetOfflineQueryStatusResponse {
    /// Batch report; a missing key deserializes to `None` because the field
    /// is an `Option`.
    pub report: Option<BatchReport>,
}
435
/// Status report for a batch (offline) operation. Every field is optional.
#[derive(Debug, Clone, Deserialize)]
pub struct BatchReport {
    #[serde(default)]
    pub operation_id: Option<String>,

    /// Status as a raw string. Unlike `DatasetRevision::status`, integer codes
    /// are not accepted here. NOTE(review): confirm the server never sends a
    /// number for this field.
    #[serde(default)]
    pub status: Option<String>,

    #[serde(default)]
    pub environment_id: Option<String>,

    /// Primary error, if any.
    #[serde(default)]
    pub error: Option<ChalkError>,

    /// All errors; a missing key or `null` becomes an empty `Vec`.
    #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")]
    pub all_errors: Vec<ChalkError>,
}
454
/// Response of the offline-query job endpoint, carrying result URLs.
#[derive(Debug, Clone, Deserialize)]
pub struct GetOfflineQueryJobResponse {
    /// Required on the wire, so a missing key is a deserialization error.
    /// NOTE(review): `OfflineQueryResponse::is_finished` instead defaults to
    /// `false`; confirm whether this asymmetry is intended.
    pub is_finished: bool,

    #[serde(default)]
    pub version: Option<i64>,

    /// Result URLs; a missing key or `null` becomes an empty `Vec`.
    #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")]
    pub urls: Vec<String>,

    #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")]
    pub errors: Vec<ChalkError>,
}
469
/// Result of a feature-upload request.
#[derive(Debug, Clone, Deserialize)]
pub struct UploadFeaturesResult {
    #[serde(default)]
    pub operation_id: Option<String>,

    /// Upload errors; a missing key or `null` becomes an empty `Vec`.
    #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")]
    pub errors: Vec<ChalkError>,
}
483
/// An error reported by the server, either per-feature or per-query.
///
/// `code`, `category` and `message` are required on the wire. The optional
/// fields are omitted when serializing a `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChalkError {
    /// Machine-readable error code (e.g. `"RESOLVER_FAILED"`).
    pub code: String,
    /// Error category (e.g. `"FIELD"`).
    pub category: String,
    /// Human-readable description.
    pub message: String,

    /// Feature the error relates to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feature: Option<String>,

    /// Resolver the error relates to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resolver: Option<String>,

    /// Exception raised by the resolver, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exception: Option<ResolverException>,
}
504
/// Exception details captured from a failing resolver.
///
/// Every field is optional when deserializing. There is no
/// `skip_serializing_if` here, so `None` fields are serialized as `null`,
/// unlike in `ChalkError`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolverException {
    /// Exception type name (e.g. `"TimeoutError"`).
    #[serde(default)]
    pub kind: Option<String>,

    #[serde(default)]
    pub message: Option<String>,

    #[serde(default)]
    pub stacktrace: Option<String>,
}
517
/// Caller-facing optional settings for an online query. Not serialized
/// directly: no serde derives.
///
/// Most fields mirror the optional fields of `OnlineQueryRequest`.
/// NOTE(review): `planner_options` has no counterpart field in
/// `OnlineQueryRequest`; confirm it is forwarded somewhere and not silently
/// dropped.
#[derive(Debug, Clone, Default)]
pub struct QueryOptions {
    pub context: Option<OnlineQueryContext>,
    pub staleness: Option<HashMap<String, String>>,
    pub include_meta: Option<bool>,
    pub query_name: Option<String>,
    pub query_name_version: Option<String>,
    pub correlation_id: Option<String>,
    pub query_context: Option<HashMap<String, Value>>,
    pub meta: Option<HashMap<String, String>>,
    pub now: Option<String>,
    pub explain: Option<bool>,
    pub store_plan_stages: Option<bool>,
    pub planner_options: Option<HashMap<String, Value>>,
    pub branch_id: Option<String>,
    pub encoding_options: Option<FeatureEncodingOptions>,
}
536
537#[derive(Debug, Serialize)]
543pub struct TokenExchangeRequest {
544 pub client_id: String,
545 pub client_secret: String,
546 pub grant_type: String,
547}
548
549#[derive(Debug, Clone, Deserialize)]
551pub struct TokenResponse {
552 pub access_token: String,
553
554 #[serde(default)]
555 pub expires_at: Option<String>,
556
557 #[serde(default)]
558 pub expires_in: Option<i64>,
559
560 #[serde(default)]
561 pub primary_environment: Option<String>,
562
563 #[serde(default)]
564 pub engines: HashMap<String, String>,
565
566 #[serde(default)]
567 pub grpc_engines: HashMap<String, String>,
568
569 #[serde(default)]
570 pub environment_id_to_name: HashMap<String, String>,
571
572 #[serde(default)]
573 pub api_server: Option<String>,
574}
575
/// Round-trip and wire-shape tests for the request/response types above.
#[cfg(test)]
mod tests {
    use super::*;

    // Set fields appear in the JSON; `None` optionals are omitted entirely.
    #[test]
    fn test_online_query_request_serialization() {
        let req = OnlineQueryRequest {
            inputs: HashMap::from([("user.id".into(), serde_json::json!(1))]),
            outputs: vec!["user.age".into(), "user.name".into()],
            context: None,
            staleness: None,
            include_meta: Some(true),
            query_name: None,
            correlation_id: None,
            query_context: None,
            meta: None,
            query_name_version: None,
            now: None,
            explain: None,
            store_plan_stages: None,
            encoding_options: None,
            branch_id: None,
        };

        let json = serde_json::to_value(&req).unwrap();

        assert_eq!(json["inputs"]["user.id"], 1);
        assert_eq!(json["outputs"][0], "user.age");
        assert_eq!(json["include_meta"], true);
        // skip_serializing_if removes the key, not just the value.
        assert!(json.get("context").is_none());
        assert!(json.get("staleness").is_none());
        assert!(json.get("query_name").is_none());
    }

    // A minimal response: one feature without per-feature meta, empty errors,
    // and partial query meta.
    #[test]
    fn test_online_query_response_deserialization() {
        let json = r#"{
            "data": [
                {
                    "field": "user.age",
                    "value": 25,
                    "ts": "2024-01-15T10:30:00Z"
                }
            ],
            "errors": [],
            "meta": {
                "execution_duration_s": 0.042,
                "query_id": "q-123"
            }
        }"#;

        let resp: OnlineQueryResponse = serde_json::from_str(json).unwrap();

        assert_eq!(resp.data.len(), 1);
        assert_eq!(resp.data[0].field, "user.age");
        assert_eq!(resp.data[0].value, serde_json::json!(25));
        assert_eq!(resp.data[0].ts.as_deref(), Some("2024-01-15T10:30:00Z"));
        assert!(resp.data[0].meta.is_none());
        assert!(resp.errors.is_empty());

        let meta = resp.meta.unwrap();
        assert_eq!(meta.execution_duration_s, Some(0.042));
        assert_eq!(meta.query_id.as_deref(), Some("q-123"));
    }

    // Serialize then deserialize a fully-populated error, including the nested
    // resolver exception.
    #[test]
    fn test_chalk_error_round_trip() {
        let err = ChalkError {
            code: "RESOLVER_FAILED".into(),
            category: "FIELD".into(),
            message: "timeout after 30s".into(),
            feature: Some("user.credit_score".into()),
            resolver: Some("get_credit_score".into()),
            exception: Some(ResolverException {
                kind: Some("TimeoutError".into()),
                message: Some("deadline exceeded".into()),
                stacktrace: None,
            }),
        };

        let json = serde_json::to_string(&err).unwrap();
        let parsed: ChalkError = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.code, "RESOLVER_FAILED");
        assert_eq!(parsed.feature.as_deref(), Some("user.credit_score"));
        assert!(parsed.exception.is_some());
        assert_eq!(
            parsed.exception.unwrap().kind.as_deref(),
            Some("TimeoutError")
        );
    }

    // Keys absent from the payload (expires_at, environment_id_to_name,
    // api_server) must fall back to their defaults.
    #[test]
    fn test_token_response_deserialization() {
        let json = r#"{
            "access_token": "eyJhbGci...",
            "expires_in": 3600,
            "primary_environment": "env-123",
            "engines": {
                "env-123": "https://engine1.chalk.ai"
            },
            "grpc_engines": {
                "env-123": "https://grpc1.chalk.ai"
            }
        }"#;

        let resp: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.access_token, "eyJhbGci...");
        assert_eq!(resp.primary_environment.as_deref(), Some("env-123"));
        assert_eq!(
            resp.engines.get("env-123").map(|s| s.as_str()),
            Some("https://engine1.chalk.ai")
        );
    }

    // Inline input serializes untagged: `input` is the bare
    // `{columns, values}` object.
    #[test]
    fn test_offline_query_request_serialization() {
        let req = OfflineQueryRequest {
            input: Some(OfflineQueryInputType::Inline(OfflineQueryInput {
                columns: vec!["user.id".into(), "user.signup_date".into()],
                // One inner vec per column, in `columns` order.
                values: vec![
                    vec![serde_json::json!(1), serde_json::json!(2)],
                    vec![serde_json::json!("2024-01-01"), serde_json::json!("2024-02-01")],
                ],
            })),
            output: vec!["user.ltv".into()],
            destination_format: Some("PARQUET".into()),
            job_id: None,
            max_samples: None,
            max_cache_age_secs: None,
            observed_at_lower_bound: None,
            observed_at_upper_bound: None,
            dataset_name: Some("training_data_v2".into()),
            branch: None,
            recompute_features: None,
            tags: None,
            required_resolver_tags: None,
            correlation_id: None,
            store_online: None,
            store_offline: None,
            required_output: None,
            run_asynchronously: None,
            num_shards: None,
            num_workers: None,
            resources: None,
            completion_deadline: None,
            max_retries: None,
            store_plan_stages: None,
            explain: None,
            planner_options: None,
            query_context: None,
            use_multiple_computers: None,
            spine_sql_query: None,
            query_name: None,
            query_name_version: None,
        };

        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["output"][0], "user.ltv");
        assert_eq!(json["input"]["columns"][0], "user.id");
        assert_eq!(json["dataset_name"], "training_data_v2");
        assert!(json.get("branch").is_none());
        assert!(json.get("use_multiple_computers").is_none());
    }

    // URI input serializes untagged as `{parquet_uri}`; the `None` row bounds
    // are omitted and no inline-input keys leak in.
    #[test]
    fn test_offline_query_request_with_uri_input() {
        let req = OfflineQueryRequest {
            input: Some(OfflineQueryInputType::Uri(OfflineQueryInputUri {
                parquet_uri: "s3://bucket/inputs.parquet".into(),
                start_row: None,
                end_row: None,
            })),
            output: vec!["user.ltv".into()],
            destination_format: Some("PARQUET".into()),
            job_id: None,
            max_samples: None,
            max_cache_age_secs: None,
            observed_at_lower_bound: None,
            observed_at_upper_bound: None,
            dataset_name: None,
            branch: None,
            recompute_features: None,
            tags: None,
            required_resolver_tags: None,
            correlation_id: None,
            store_online: None,
            store_offline: None,
            required_output: None,
            run_asynchronously: None,
            num_shards: None,
            num_workers: None,
            resources: None,
            completion_deadline: None,
            max_retries: None,
            store_plan_stages: None,
            explain: None,
            planner_options: None,
            query_context: None,
            use_multiple_computers: None,
            spine_sql_query: None,
            query_name: None,
            query_name_version: None,
        };

        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["input"]["parquet_uri"], "s3://bucket/inputs.parquet");
        assert!(json["input"].get("columns").is_none());
    }
}