1use std::collections::HashMap;
11
12use chrono::{DateTime, Utc};
13use serde_json::Value;
14
15use crate::error::{ChalkClientError, Result};
16use crate::types::{
17 OfflineQueryInput, OfflineQueryInputSql, OfflineQueryInputType, OfflineQueryInputUri,
18 OfflineQueryRequest, ResourceRequests,
19};
20
/// Reserved column name under which observation timestamps are sent when
/// inline input rows carry explicit `input_times` (see `build()`).
const CHALK_TS_COLUMN: &str = "__chalk__.CHALK_TS";
23
/// Builder for an [`OfflineQueryRequest`].
///
/// Populate it with the `with_*` methods (or the `from_uri` / `from_sql`
/// constructors for non-inline input) and finish with `build()`, which
/// validates the parameters and produces the request payload.
#[derive(Debug, Clone)]
pub struct OfflineQueryParams {
    // Inline input: feature name -> one value per row.
    inputs: HashMap<String, Vec<Value>>,
    // Optional per-row observation timestamps; appended as the
    // `CHALK_TS_COLUMN` column when inline input is built.
    input_times: Vec<DateTime<Utc>>,
    // Non-inline input source (parquet URI or SQL); takes precedence over
    // `inputs` in `build()`.
    input_type: Option<OfflineQueryInputType>,
    // Features to compute (best-effort).
    output: Vec<String>,
    // Features that must be computed for the query to succeed.
    required_output: Vec<String>,

    // Remaining fields are optional pass-throughs copied verbatim into the
    // request by `build()`; `None` means "omit from the request".
    destination_format: Option<String>,
    job_id: Option<String>,
    max_samples: Option<i64>,
    max_cache_age_secs: Option<i64>,
    observed_at_lower_bound: Option<String>,
    observed_at_upper_bound: Option<String>,
    dataset_name: Option<String>,
    branch: Option<String>,
    recompute_features: Option<Value>,
    tags: Option<Vec<String>>,
    required_resolver_tags: Option<Vec<String>>,
    correlation_id: Option<String>,
    store_online: Option<bool>,
    store_offline: Option<bool>,
    run_asynchronously: Option<bool>,
    num_shards: Option<i64>,
    num_workers: Option<i64>,
    resources: Option<ResourceRequests>,
    completion_deadline: Option<String>,
    max_retries: Option<i64>,
    store_plan_stages: Option<bool>,
    explain: Option<bool>,
    planner_options: Option<HashMap<String, Value>>,
    query_context: Option<HashMap<String, Value>>,
    spine_sql_query: Option<String>,
    query_name: Option<String>,
    query_name_version: Option<String>,
}
74
75impl OfflineQueryParams {
76 pub fn new() -> Self {
78 Self {
79 inputs: HashMap::new(),
80 input_times: Vec::new(),
81 input_type: None,
82 output: Vec::new(),
83 required_output: Vec::new(),
84 destination_format: None,
85 job_id: None,
86 max_samples: None,
87 max_cache_age_secs: None,
88 observed_at_lower_bound: None,
89 observed_at_upper_bound: None,
90 dataset_name: None,
91 branch: None,
92 recompute_features: None,
93 tags: None,
94 required_resolver_tags: None,
95 correlation_id: None,
96 store_online: None,
97 store_offline: None,
98 run_asynchronously: None,
99 num_shards: None,
100 num_workers: None,
101 resources: None,
102 completion_deadline: None,
103 max_retries: None,
104 store_plan_stages: None,
105 explain: None,
106 planner_options: None,
107 query_context: None,
108 spine_sql_query: None,
109 query_name: None,
110 query_name_version: None,
111 }
112 }
113
114 pub fn from_uri(parquet_uri: impl Into<String>) -> Self {
116 let mut params = Self::new();
117 params.input_type = Some(OfflineQueryInputType::Uri(OfflineQueryInputUri {
118 parquet_uri: parquet_uri.into(),
119 start_row: None,
120 end_row: None,
121 }));
122 params
123 }
124
125 pub fn from_uri_with_range(
127 parquet_uri: impl Into<String>,
128 start_row: Option<i64>,
129 end_row: Option<i64>,
130 ) -> Self {
131 let mut params = Self::new();
132 params.input_type = Some(OfflineQueryInputType::Uri(OfflineQueryInputUri {
133 parquet_uri: parquet_uri.into(),
134 start_row,
135 end_row,
136 }));
137 params
138 }
139
140 pub fn from_sql(input_sql: impl Into<String>) -> Self {
142 let mut params = Self::new();
143 params.input_type = Some(OfflineQueryInputType::Sql(OfflineQueryInputSql {
144 input_sql: input_sql.into(),
145 }));
146 params
147 }
148
149 pub fn with_input(mut self, feature: impl Into<String>, values: Vec<Value>) -> Self {
151 self.inputs.insert(feature.into(), values);
152 self
153 }
154
155 pub fn with_input_times(mut self, times: Vec<DateTime<Utc>>) -> Self {
157 self.input_times = times;
158 self
159 }
160
161 pub fn with_output(mut self, feature: impl Into<String>) -> Self {
163 self.output.push(feature.into());
164 self
165 }
166
167 pub fn with_required_output(mut self, feature: impl Into<String>) -> Self {
169 self.required_output.push(feature.into());
170 self
171 }
172
173 pub fn with_destination_format(mut self, format: impl Into<String>) -> Self {
174 self.destination_format = Some(format.into());
175 self
176 }
177
178 pub fn with_job_id(mut self, id: impl Into<String>) -> Self {
179 self.job_id = Some(id.into());
180 self
181 }
182
183 pub fn with_max_samples(mut self, n: i64) -> Self {
184 self.max_samples = Some(n);
185 self
186 }
187
188 pub fn with_max_cache_age_secs(mut self, secs: i64) -> Self {
189 self.max_cache_age_secs = Some(secs);
190 self
191 }
192
193 pub fn with_observed_at_lower_bound(mut self, bound: impl Into<String>) -> Self {
194 self.observed_at_lower_bound = Some(bound.into());
195 self
196 }
197
198 pub fn with_observed_at_upper_bound(mut self, bound: impl Into<String>) -> Self {
199 self.observed_at_upper_bound = Some(bound.into());
200 self
201 }
202
203 pub fn with_dataset_name(mut self, name: impl Into<String>) -> Self {
204 self.dataset_name = Some(name.into());
205 self
206 }
207
208 pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
209 self.branch = Some(branch.into());
210 self
211 }
212
213 pub fn with_recompute_features(mut self, recompute: Value) -> Self {
214 self.recompute_features = Some(recompute);
215 self
216 }
217
218 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
219 self.tags = Some(tags);
220 self
221 }
222
223 pub fn with_required_resolver_tags(mut self, tags: Vec<String>) -> Self {
224 self.required_resolver_tags = Some(tags);
225 self
226 }
227
228 pub fn with_correlation_id(mut self, id: impl Into<String>) -> Self {
229 self.correlation_id = Some(id.into());
230 self
231 }
232
233 pub fn with_store_online(mut self, store: bool) -> Self {
234 self.store_online = Some(store);
235 self
236 }
237
238 pub fn with_store_offline(mut self, store: bool) -> Self {
239 self.store_offline = Some(store);
240 self
241 }
242
243 pub fn with_run_asynchronously(mut self, async_: bool) -> Self {
244 self.run_asynchronously = Some(async_);
245 self
246 }
247
248 pub fn with_num_shards(mut self, n: i64) -> Self {
249 self.num_shards = Some(n);
250 self
251 }
252
253 pub fn with_num_workers(mut self, n: i64) -> Self {
254 self.num_workers = Some(n);
255 self
256 }
257
258 pub fn with_resources(mut self, resources: ResourceRequests) -> Self {
259 self.resources = Some(resources);
260 self
261 }
262
263 pub fn with_completion_deadline(mut self, deadline: impl Into<String>) -> Self {
264 self.completion_deadline = Some(deadline.into());
265 self
266 }
267
268 pub fn with_max_retries(mut self, n: i64) -> Self {
269 self.max_retries = Some(n);
270 self
271 }
272
273 pub fn with_store_plan_stages(mut self, store: bool) -> Self {
274 self.store_plan_stages = Some(store);
275 self
276 }
277
278 pub fn with_explain(mut self, explain: bool) -> Self {
279 self.explain = Some(explain);
280 self
281 }
282
283 pub fn with_planner_options(mut self, options: HashMap<String, Value>) -> Self {
284 self.planner_options = Some(options);
285 self
286 }
287
288 pub fn with_query_context(mut self, context: HashMap<String, Value>) -> Self {
289 self.query_context = Some(context);
290 self
291 }
292
293 pub fn with_spine_sql_query(mut self, sql: impl Into<String>) -> Self {
294 self.spine_sql_query = Some(sql.into());
295 self
296 }
297
298 pub fn with_query_name(mut self, name: impl Into<String>) -> Self {
299 self.query_name = Some(name.into());
300 self
301 }
302
303 pub fn with_query_name_version(mut self, version: impl Into<String>) -> Self {
304 self.query_name_version = Some(version.into());
305 self
306 }
307
308 pub fn build(self) -> Result<OfflineQueryRequest> {
310 if self.output.is_empty() && self.required_output.is_empty() {
311 return Err(ChalkClientError::Config(
312 "offline query requires at least one output or required_output".into(),
313 ));
314 }
315
316 let spine_sql_query = match &self.input_type {
317 Some(OfflineQueryInputType::Sql(sql)) => Some(sql.input_sql.clone()),
318 _ => self.spine_sql_query,
319 };
320
321 let input = if let Some(input_type) = self.input_type {
322 match input_type {
323 OfflineQueryInputType::Inline(_)
324 | OfflineQueryInputType::Uri(_) => Some(input_type),
325 OfflineQueryInputType::Sql(_) => None,
326 }
327 } else if self.inputs.is_empty() {
328 None
329 } else {
330 let mut columns: Vec<String> = self.inputs.keys().cloned().collect();
331 columns.sort();
332
333 let mut values: Vec<Vec<Value>> = Vec::with_capacity(columns.len());
334 for col in &columns {
335 values.push(self.inputs[col].clone());
336 }
337
338 if !self.input_times.is_empty() {
339 columns.push(CHALK_TS_COLUMN.to_string());
340 let ts_values: Vec<Value> = self
341 .input_times
342 .iter()
343 .map(|ts| Value::String(ts.to_rfc3339()))
344 .collect();
345 values.push(ts_values);
346 }
347
348 Some(OfflineQueryInputType::Inline(OfflineQueryInput { columns, values }))
349 };
350
351 let required_output = if self.required_output.is_empty() {
352 None
353 } else {
354 Some(self.required_output)
355 };
356
357 let use_multiple_computers = if self.num_shards.is_some()
358 || self.num_workers.is_some()
359 || self.run_asynchronously == Some(true)
360 {
361 Some(true)
362 } else {
363 None
364 };
365
366 Ok(OfflineQueryRequest {
367 input,
368 output: self.output,
369 destination_format: self.destination_format,
370 job_id: self.job_id,
371 max_samples: self.max_samples,
372 max_cache_age_secs: self.max_cache_age_secs,
373 observed_at_lower_bound: self.observed_at_lower_bound,
374 observed_at_upper_bound: self.observed_at_upper_bound,
375 dataset_name: self.dataset_name,
376 branch: self.branch,
377 recompute_features: self.recompute_features,
378 tags: self.tags,
379 required_resolver_tags: self.required_resolver_tags,
380 correlation_id: self.correlation_id,
381 store_online: self.store_online,
382 store_offline: self.store_offline,
383 required_output,
384 run_asynchronously: self.run_asynchronously,
385 num_shards: self.num_shards,
386 num_workers: self.num_workers,
387 resources: self.resources,
388 completion_deadline: self.completion_deadline,
389 max_retries: self.max_retries,
390 store_plan_stages: self.store_plan_stages,
391 explain: self.explain,
392 planner_options: self.planner_options,
393 query_context: self.query_context,
394 use_multiple_computers,
395 spine_sql_query,
396 query_name: self.query_name,
397 query_name_version: self.query_name_version,
398 })
399 }
400}
401
402impl Default for OfflineQueryParams {
403 fn default() -> Self {
404 Self::new()
405 }
406}
407
#[cfg(test)]
mod tests {
    //! Unit tests for the `OfflineQueryParams` builder: input assembly,
    //! validation, and serialization of the built request.
    use super::*;
    use chrono::TimeZone;
    use serde_json::json;

    // Inline input columns and outputs round-trip into the serialized request.
    #[test]
    fn test_builder_inline_input_serialization() {
        let params = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1), json!(2), json!(3)])
            .with_output("user.email")
            .with_output("user.ltv");

        let req = params.build().unwrap();
        let json = serde_json::to_value(&req).unwrap();

        let input = &json["input"];
        assert_eq!(input["columns"][0], "user.id");
        assert_eq!(input["values"][0][0], 1);
        assert_eq!(input["values"][0][1], 2);
        assert_eq!(input["values"][0][2], 3);
        assert_eq!(json["output"][0], "user.email");
        assert_eq!(json["output"][1], "user.ltv");
    }

    // `input_times` are appended as the reserved CHALK_TS_COLUMN column,
    // serialized as RFC 3339 strings.
    #[test]
    fn test_builder_with_timestamps() {
        let ts1 = Utc.with_ymd_and_hms(2024, 1, 15, 10, 0, 0).unwrap();
        let ts2 = Utc.with_ymd_and_hms(2024, 2, 15, 10, 0, 0).unwrap();

        let params = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1), json!(2)])
            .with_input_times(vec![ts1, ts2])
            .with_output("user.email");

        let req = params.build().unwrap();
        let input_type = req.input.unwrap();
        let input = match input_type {
            OfflineQueryInputType::Inline(inline) => inline,
            _ => panic!("expected Inline input"),
        };

        assert!(input.columns.contains(&CHALK_TS_COLUMN.to_string()));
        let ts_col_idx = input
            .columns
            .iter()
            .position(|c| c == CHALK_TS_COLUMN)
            .unwrap();
        let ts_val = input.values[ts_col_idx][0].as_str().unwrap();
        assert!(ts_val.contains("2024-01-15"));
    }

    // Building with neither `output` nor `required_output` is a config error.
    #[test]
    fn test_builder_validation_no_outputs() {
        let params = OfflineQueryParams::new().with_input("user.id", vec![json!(1)]);

        let result = params.build();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("at least one output"));
    }

    // `required_output` alone satisfies the output validation.
    #[test]
    fn test_builder_required_output_only() {
        let params = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1)])
            .with_required_output("user.email");

        let req = params.build().unwrap();
        assert!(req.output.is_empty());
        assert_eq!(req.required_output.as_ref().unwrap()[0], "user.email");
    }

    // URI input is passed through as the request's `input` payload.
    #[test]
    fn test_from_uri() {
        let params = OfflineQueryParams::from_uri("s3://bucket/inputs.parquet")
            .with_output("user.email");

        let req = params.build().unwrap();
        let input = req.input.unwrap();
        let json = serde_json::to_value(&input).unwrap();
        assert_eq!(json["parquet_uri"], "s3://bucket/inputs.parquet");
    }

    // URI input with a row range serializes all three fields.
    #[test]
    fn test_from_uri_serialization() {
        let input = OfflineQueryInputUri {
            parquet_uri: "s3://bucket/inputs.parquet".into(),
            start_row: Some(0),
            end_row: Some(1000),
        };

        let json = serde_json::to_value(&input).unwrap();
        assert_eq!(json["parquet_uri"], "s3://bucket/inputs.parquet");
        assert_eq!(json["start_row"], 0);
        assert_eq!(json["end_row"], 1000);
    }

    // SQL input becomes `spine_sql_query`, not an `input` payload.
    #[test]
    fn test_from_sql() {
        let params =
            OfflineQueryParams::from_sql("SELECT user_id FROM events").with_output("user.email");

        let req = params.build().unwrap();
        assert!(req.input.is_none());
        assert_eq!(
            req.spine_sql_query.as_deref(),
            Some("SELECT user_id FROM events")
        );
    }

    // The raw SQL input type serializes its statement under `input_sql`.
    #[test]
    fn test_from_sql_input_serialization() {
        let input = OfflineQueryInputSql {
            input_sql: "SELECT user_id FROM events".into(),
        };

        let json = serde_json::to_value(&input).unwrap();
        assert_eq!(json["input_sql"], "SELECT user_id FROM events");
    }

    // Multi-computer options pass through, and any of shards/workers/async
    // sets `use_multiple_computers`.
    #[test]
    fn test_builder_all_options() {
        let params = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1)])
            .with_output("user.email")
            .with_num_shards(4)
            .with_num_workers(2)
            .with_run_asynchronously(true)
            .with_dataset_name("my_dataset")
            .with_max_retries(3)
            .with_completion_deadline("3600s");

        let req = params.build().unwrap();
        assert_eq!(req.num_shards, Some(4));
        assert_eq!(req.num_workers, Some(2));
        assert_eq!(req.run_asynchronously, Some(true));
        assert_eq!(req.dataset_name.as_deref(), Some("my_dataset"));
        assert_eq!(req.max_retries, Some(3));
        assert_eq!(req.completion_deadline.as_deref(), Some("3600s"));
        assert_eq!(req.use_multiple_computers, Some(true));
    }

    // Without shards/workers/async, `use_multiple_computers` stays unset.
    #[test]
    fn test_use_multiple_computers_not_set_by_default() {
        let params = OfflineQueryParams::new()
            .with_input("user.id", vec![json!(1)])
            .with_output("user.email");

        let req = params.build().unwrap();
        assert!(req.use_multiple_computers.is_none());
    }
}