Skip to main content

chalk_client/
offline.rs

1//! Fluent builder for offline query parameters.
2//!
3//! [`OfflineQueryParams`] provides a chainable API for constructing offline
4//! queries. It supports three input modes:
5//!
6//! - **Inline data**: `OfflineQueryParams::new()` with `.with_input()` calls
7//! - **Parquet URI**: `OfflineQueryParams::from_uri("s3://...")`
8//! - **SQL query**: `OfflineQueryParams::from_sql("SELECT ...")`
9
10use std::collections::HashMap;
11
12use chrono::{DateTime, Utc};
13use serde_json::Value;
14
15use crate::error::{ChalkClientError, Result};
16use crate::types::{
17    OfflineQueryInput, OfflineQueryInputSql, OfflineQueryInputType, OfflineQueryInputUri,
18    OfflineQueryRequest, ResourceRequests,
19};
20
/// The special column name Chalk uses for per-row observation timestamps.
///
/// `OfflineQueryParams::build` appends a column with this name as the last
/// inline input column whenever input times were supplied alongside inline
/// input data; its values are the RFC 3339 renderings of those timestamps.
const CHALK_TS_COLUMN: &str = "__chalk__.CHALK_TS";
23
24/// Builder for offline query parameters.
25///
26/// # Examples
27///
28/// ```
29/// use chalk_client::OfflineQueryParams;
30/// use serde_json::json;
31///
32/// let params = OfflineQueryParams::new()
33///     .with_input("user.id", vec![json!(1), json!(2), json!(3)])
34///     .with_output("user.email")
35///     .with_output("user.ltv")
36///     .with_num_shards(4);
37/// ```
38#[derive(Debug, Clone)]
39pub struct OfflineQueryParams {
40    inputs: HashMap<String, Vec<Value>>,
41    input_times: Vec<DateTime<Utc>>,
42    input_type: Option<OfflineQueryInputType>,
43    output: Vec<String>,
44    required_output: Vec<String>,
45
46    destination_format: Option<String>,
47    job_id: Option<String>,
48    max_samples: Option<i64>,
49    max_cache_age_secs: Option<i64>,
50    observed_at_lower_bound: Option<String>,
51    observed_at_upper_bound: Option<String>,
52    dataset_name: Option<String>,
53    branch: Option<String>,
54    recompute_features: Option<Value>,
55    tags: Option<Vec<String>>,
56    required_resolver_tags: Option<Vec<String>>,
57    correlation_id: Option<String>,
58    store_online: Option<bool>,
59    store_offline: Option<bool>,
60    run_asynchronously: Option<bool>,
61    num_shards: Option<i64>,
62    num_workers: Option<i64>,
63    resources: Option<ResourceRequests>,
64    completion_deadline: Option<String>,
65    max_retries: Option<i64>,
66    store_plan_stages: Option<bool>,
67    explain: Option<bool>,
68    planner_options: Option<HashMap<String, Value>>,
69    query_context: Option<HashMap<String, Value>>,
70    spine_sql_query: Option<String>,
71    query_name: Option<String>,
72    query_name_version: Option<String>,
73}
74
75impl OfflineQueryParams {
76    /// Create a new builder for inline input data.
77    pub fn new() -> Self {
78        Self {
79            inputs: HashMap::new(),
80            input_times: Vec::new(),
81            input_type: None,
82            output: Vec::new(),
83            required_output: Vec::new(),
84            destination_format: None,
85            job_id: None,
86            max_samples: None,
87            max_cache_age_secs: None,
88            observed_at_lower_bound: None,
89            observed_at_upper_bound: None,
90            dataset_name: None,
91            branch: None,
92            recompute_features: None,
93            tags: None,
94            required_resolver_tags: None,
95            correlation_id: None,
96            store_online: None,
97            store_offline: None,
98            run_asynchronously: None,
99            num_shards: None,
100            num_workers: None,
101            resources: None,
102            completion_deadline: None,
103            max_retries: None,
104            store_plan_stages: None,
105            explain: None,
106            planner_options: None,
107            query_context: None,
108            spine_sql_query: None,
109            query_name: None,
110            query_name_version: None,
111        }
112    }
113
114    /// Create a builder that reads input from a Parquet file at the given URI.
115    pub fn from_uri(parquet_uri: impl Into<String>) -> Self {
116        let mut params = Self::new();
117        params.input_type = Some(OfflineQueryInputType::Uri(OfflineQueryInputUri {
118            parquet_uri: parquet_uri.into(),
119            start_row: None,
120            end_row: None,
121        }));
122        params
123    }
124
125    /// Create a builder that reads input from a Parquet URI with row range.
126    pub fn from_uri_with_range(
127        parquet_uri: impl Into<String>,
128        start_row: Option<i64>,
129        end_row: Option<i64>,
130    ) -> Self {
131        let mut params = Self::new();
132        params.input_type = Some(OfflineQueryInputType::Uri(OfflineQueryInputUri {
133            parquet_uri: parquet_uri.into(),
134            start_row,
135            end_row,
136        }));
137        params
138    }
139
140    /// Create a builder that generates input data from a SQL query.
141    pub fn from_sql(input_sql: impl Into<String>) -> Self {
142        let mut params = Self::new();
143        params.input_type = Some(OfflineQueryInputType::Sql(OfflineQueryInputSql {
144            input_sql: input_sql.into(),
145        }));
146        params
147    }
148
149    /// Add an input column with values.
150    pub fn with_input(mut self, feature: impl Into<String>, values: Vec<Value>) -> Self {
151        self.inputs.insert(feature.into(), values);
152        self
153    }
154
155    /// Set per-row observation timestamps.
156    pub fn with_input_times(mut self, times: Vec<DateTime<Utc>>) -> Self {
157        self.input_times = times;
158        self
159    }
160
161    /// Add a feature to the output list.
162    pub fn with_output(mut self, feature: impl Into<String>) -> Self {
163        self.output.push(feature.into());
164        self
165    }
166
167    /// Add a feature to the required output list.
168    pub fn with_required_output(mut self, feature: impl Into<String>) -> Self {
169        self.required_output.push(feature.into());
170        self
171    }
172
173    pub fn with_destination_format(mut self, format: impl Into<String>) -> Self {
174        self.destination_format = Some(format.into());
175        self
176    }
177
178    pub fn with_job_id(mut self, id: impl Into<String>) -> Self {
179        self.job_id = Some(id.into());
180        self
181    }
182
183    pub fn with_max_samples(mut self, n: i64) -> Self {
184        self.max_samples = Some(n);
185        self
186    }
187
188    pub fn with_max_cache_age_secs(mut self, secs: i64) -> Self {
189        self.max_cache_age_secs = Some(secs);
190        self
191    }
192
193    pub fn with_observed_at_lower_bound(mut self, bound: impl Into<String>) -> Self {
194        self.observed_at_lower_bound = Some(bound.into());
195        self
196    }
197
198    pub fn with_observed_at_upper_bound(mut self, bound: impl Into<String>) -> Self {
199        self.observed_at_upper_bound = Some(bound.into());
200        self
201    }
202
203    pub fn with_dataset_name(mut self, name: impl Into<String>) -> Self {
204        self.dataset_name = Some(name.into());
205        self
206    }
207
208    pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
209        self.branch = Some(branch.into());
210        self
211    }
212
213    pub fn with_recompute_features(mut self, recompute: Value) -> Self {
214        self.recompute_features = Some(recompute);
215        self
216    }
217
218    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
219        self.tags = Some(tags);
220        self
221    }
222
223    pub fn with_required_resolver_tags(mut self, tags: Vec<String>) -> Self {
224        self.required_resolver_tags = Some(tags);
225        self
226    }
227
228    pub fn with_correlation_id(mut self, id: impl Into<String>) -> Self {
229        self.correlation_id = Some(id.into());
230        self
231    }
232
233    pub fn with_store_online(mut self, store: bool) -> Self {
234        self.store_online = Some(store);
235        self
236    }
237
238    pub fn with_store_offline(mut self, store: bool) -> Self {
239        self.store_offline = Some(store);
240        self
241    }
242
243    pub fn with_run_asynchronously(mut self, async_: bool) -> Self {
244        self.run_asynchronously = Some(async_);
245        self
246    }
247
248    pub fn with_num_shards(mut self, n: i64) -> Self {
249        self.num_shards = Some(n);
250        self
251    }
252
253    pub fn with_num_workers(mut self, n: i64) -> Self {
254        self.num_workers = Some(n);
255        self
256    }
257
258    pub fn with_resources(mut self, resources: ResourceRequests) -> Self {
259        self.resources = Some(resources);
260        self
261    }
262
263    pub fn with_completion_deadline(mut self, deadline: impl Into<String>) -> Self {
264        self.completion_deadline = Some(deadline.into());
265        self
266    }
267
268    pub fn with_max_retries(mut self, n: i64) -> Self {
269        self.max_retries = Some(n);
270        self
271    }
272
273    pub fn with_store_plan_stages(mut self, store: bool) -> Self {
274        self.store_plan_stages = Some(store);
275        self
276    }
277
278    pub fn with_explain(mut self, explain: bool) -> Self {
279        self.explain = Some(explain);
280        self
281    }
282
283    pub fn with_planner_options(mut self, options: HashMap<String, Value>) -> Self {
284        self.planner_options = Some(options);
285        self
286    }
287
288    pub fn with_query_context(mut self, context: HashMap<String, Value>) -> Self {
289        self.query_context = Some(context);
290        self
291    }
292
293    pub fn with_spine_sql_query(mut self, sql: impl Into<String>) -> Self {
294        self.spine_sql_query = Some(sql.into());
295        self
296    }
297
298    pub fn with_query_name(mut self, name: impl Into<String>) -> Self {
299        self.query_name = Some(name.into());
300        self
301    }
302
303    pub fn with_query_name_version(mut self, version: impl Into<String>) -> Self {
304        self.query_name_version = Some(version.into());
305        self
306    }
307
308    /// Build the [`OfflineQueryRequest`].
309    pub fn build(self) -> Result<OfflineQueryRequest> {
310        if self.output.is_empty() && self.required_output.is_empty() {
311            return Err(ChalkClientError::Config(
312                "offline query requires at least one output or required_output".into(),
313            ));
314        }
315
316        let spine_sql_query = match &self.input_type {
317            Some(OfflineQueryInputType::Sql(sql)) => Some(sql.input_sql.clone()),
318            _ => self.spine_sql_query,
319        };
320
321        let input = if let Some(input_type) = self.input_type {
322            match input_type {
323                OfflineQueryInputType::Inline(_)
324                | OfflineQueryInputType::Uri(_) => Some(input_type),
325                OfflineQueryInputType::Sql(_) => None,
326            }
327        } else if self.inputs.is_empty() {
328            None
329        } else {
330            let mut columns: Vec<String> = self.inputs.keys().cloned().collect();
331            columns.sort();
332
333            let mut values: Vec<Vec<Value>> = Vec::with_capacity(columns.len());
334            for col in &columns {
335                values.push(self.inputs[col].clone());
336            }
337
338            if !self.input_times.is_empty() {
339                columns.push(CHALK_TS_COLUMN.to_string());
340                let ts_values: Vec<Value> = self
341                    .input_times
342                    .iter()
343                    .map(|ts| Value::String(ts.to_rfc3339()))
344                    .collect();
345                values.push(ts_values);
346            }
347
348            Some(OfflineQueryInputType::Inline(OfflineQueryInput { columns, values }))
349        };
350
351        let required_output = if self.required_output.is_empty() {
352            None
353        } else {
354            Some(self.required_output)
355        };
356
357        let use_multiple_computers = if self.num_shards.is_some()
358            || self.num_workers.is_some()
359            || self.run_asynchronously == Some(true)
360        {
361            Some(true)
362        } else {
363            None
364        };
365
366        Ok(OfflineQueryRequest {
367            input,
368            output: self.output,
369            destination_format: self.destination_format,
370            job_id: self.job_id,
371            max_samples: self.max_samples,
372            max_cache_age_secs: self.max_cache_age_secs,
373            observed_at_lower_bound: self.observed_at_lower_bound,
374            observed_at_upper_bound: self.observed_at_upper_bound,
375            dataset_name: self.dataset_name,
376            branch: self.branch,
377            recompute_features: self.recompute_features,
378            tags: self.tags,
379            required_resolver_tags: self.required_resolver_tags,
380            correlation_id: self.correlation_id,
381            store_online: self.store_online,
382            store_offline: self.store_offline,
383            required_output,
384            run_asynchronously: self.run_asynchronously,
385            num_shards: self.num_shards,
386            num_workers: self.num_workers,
387            resources: self.resources,
388            completion_deadline: self.completion_deadline,
389            max_retries: self.max_retries,
390            store_plan_stages: self.store_plan_stages,
391            explain: self.explain,
392            planner_options: self.planner_options,
393            query_context: self.query_context,
394            use_multiple_computers,
395            spine_sql_query,
396            query_name: self.query_name,
397            query_name_version: self.query_name_version,
398        })
399    }
400}
401
402impl Default for OfflineQueryParams {
403    fn default() -> Self {
404        Self::new()
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;
    use serde_json::json;

    /// Inline columns and the output list show up in the serialized request.
    #[test]
    fn test_builder_inline_input_serialization() {
        let request = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1), json!(2), json!(3)])
            .with_output("user.email")
            .with_output("user.ltv")
            .build()
            .unwrap();
        let serialized = serde_json::to_value(&request).unwrap();

        let columns = &serialized["input"]["columns"];
        let first_column = &serialized["input"]["values"][0];
        assert_eq!(columns[0], "user.id");
        assert_eq!(first_column[0], 1);
        assert_eq!(first_column[1], 2);
        assert_eq!(first_column[2], 3);

        let outputs = &serialized["output"];
        assert_eq!(outputs[0], "user.email");
        assert_eq!(outputs[1], "user.ltv");
    }

    /// Input times are emitted as the dedicated timestamp column.
    #[test]
    fn test_builder_with_timestamps() {
        let january = Utc.with_ymd_and_hms(2024, 1, 15, 10, 0, 0).unwrap();
        let february = Utc.with_ymd_and_hms(2024, 2, 15, 10, 0, 0).unwrap();

        let request = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1), json!(2)])
            .with_input_times(vec![january, february])
            .with_output("user.email")
            .build()
            .unwrap();

        let inline = if let OfflineQueryInputType::Inline(inline) = request.input.unwrap() {
            inline
        } else {
            panic!("expected Inline input")
        };

        assert!(inline.columns.iter().any(|name| name == CHALK_TS_COLUMN));
        let ts_position = inline
            .columns
            .iter()
            .position(|name| name == CHALK_TS_COLUMN)
            .unwrap();
        let first_ts = inline.values[ts_position][0].as_str().unwrap();
        assert!(first_ts.contains("2024-01-15"));
    }

    /// Building without any output or required output is rejected.
    #[test]
    fn test_builder_validation_no_outputs() {
        let outcome = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1)])
            .build();

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("at least one output"));
    }

    /// A required output alone is enough to build a request.
    #[test]
    fn test_builder_required_output_only() {
        let request = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1)])
            .with_required_output("user.email")
            .build()
            .unwrap();

        assert!(request.output.is_empty());
        let required = request.required_output.as_ref().unwrap();
        assert_eq!(required[0], "user.email");
    }

    /// A URI builder carries the Parquet URI into the request input.
    #[test]
    fn test_from_uri() {
        let request = OfflineQueryParams::from_uri("s3://bucket/inputs.parquet")
            .with_output("user.email")
            .build()
            .unwrap();

        let serialized = serde_json::to_value(&request.input.unwrap()).unwrap();
        assert_eq!(serialized["parquet_uri"], "s3://bucket/inputs.parquet");
    }

    /// The URI input type serializes its URI and row bounds.
    #[test]
    fn test_from_uri_serialization() {
        let uri_input = OfflineQueryInputUri {
            parquet_uri: "s3://bucket/inputs.parquet".into(),
            start_row: Some(0),
            end_row: Some(1000),
        };
        let serialized = serde_json::to_value(&uri_input).unwrap();

        assert_eq!(serialized["parquet_uri"], "s3://bucket/inputs.parquet");
        assert_eq!(serialized["start_row"], 0);
        assert_eq!(serialized["end_row"], 1000);
    }

    /// A SQL builder sends the SQL as the spine query and leaves input unset.
    #[test]
    fn test_from_sql() {
        let request = OfflineQueryParams::from_sql("SELECT user_id FROM events")
            .with_output("user.email")
            .build()
            .unwrap();

        assert!(request.input.is_none());
        assert_eq!(
            request.spine_sql_query.as_deref(),
            Some("SELECT user_id FROM events")
        );
    }

    /// The SQL input type serializes its query text.
    #[test]
    fn test_from_sql_input_serialization() {
        let sql_input = OfflineQueryInputSql {
            input_sql: "SELECT user_id FROM events".into(),
        };
        let serialized = serde_json::to_value(&sql_input).unwrap();

        assert_eq!(serialized["input_sql"], "SELECT user_id FROM events");
    }

    /// Optional settings are copied into the request and enable multi-computer mode.
    #[test]
    fn test_builder_all_options() {
        let request = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1)])
            .with_output("user.email")
            .with_num_shards(4)
            .with_num_workers(2)
            .with_run_asynchronously(true)
            .with_dataset_name("my_dataset")
            .with_max_retries(3)
            .with_completion_deadline("3600s")
            .build()
            .unwrap();

        assert_eq!(request.num_shards, Some(4));
        assert_eq!(request.num_workers, Some(2));
        assert_eq!(request.run_asynchronously, Some(true));
        assert_eq!(request.dataset_name.as_deref(), Some("my_dataset"));
        assert_eq!(request.max_retries, Some(3));
        assert_eq!(request.completion_deadline.as_deref(), Some("3600s"));
        assert_eq!(request.use_multiple_computers, Some(true));
    }

    /// Without sharding, workers or async execution the flag stays unset.
    #[test]
    fn test_use_multiple_computers_not_set_by_default() {
        let request = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1)])
            .with_output("user.email")
            .build()
            .unwrap();

        assert!(request.use_multiple_computers.is_none());
    }
}