1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7 NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8};
9
/// Estimated cost of executing (part of) a query plan, split by resource.
///
/// The four components are plain `f64` quantities that combine
/// component-wise under `+` / `+=`. [`Cost::total`] collapses them into a
/// single number so that alternatives can be compared.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// CPU component of the estimate.
    pub cpu: f64,
    /// I/O component of the estimate.
    pub io: f64,
    /// Memory component of the estimate.
    pub memory: f64,
    /// Network component of the estimate.
    pub network: f64,
}

impl Cost {
    /// Weight applied to the I/O component by `total`.
    const IO_WEIGHT: f64 = 10.0;
    /// Weight applied to the memory component by `total`.
    const MEMORY_WEIGHT: f64 = 0.1;
    /// Weight applied to the network component by `total`.
    const NETWORK_WEIGHT: f64 = 100.0;

    /// Returns a cost whose four components are all `0.0`.
    #[must_use]
    pub fn zero() -> Self {
        Self::cpu(0.0)
    }

    /// Returns a cost with the given CPU component and every other
    /// component set to `0.0`.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Returns `self` with the I/O component replaced by `io`.
    #[must_use]
    pub fn with_io(self, io: f64) -> Self {
        Self { io, ..self }
    }

    /// Returns `self` with the memory component replaced by `memory`.
    #[must_use]
    pub fn with_memory(self, memory: f64) -> Self {
        Self { memory, ..self }
    }

    /// Returns `self` with the network component replaced by `network`.
    ///
    /// Completes the builder set: the network component carries the largest
    /// weight in [`Cost::total`], so it should be settable the same way as
    /// the I/O and memory components.
    #[must_use]
    pub fn with_network(self, network: f64) -> Self {
        Self { network, ..self }
    }

    /// Collapses the cost into a single number using fixed weights:
    /// `cpu + io * 10 + memory * 0.1 + network * 100`.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu
            + self.io * Self::IO_WEIGHT
            + self.memory * Self::MEMORY_WEIGHT
            + self.network * Self::NETWORK_WEIGHT
    }

    /// Collapses the cost into a single number using caller-supplied weights
    /// for the CPU, I/O and memory components.
    ///
    /// Note that the network component does not contribute to the result:
    /// this method takes no weight for it.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum of two costs.
    fn add(mut self, other: Self) -> Self {
        // Delegate to `AddAssign` so both operators share one definition.
        self += other;
        self
    }
}

impl std::ops::AddAssign for Cost {
    /// Adds every component of `other` to the matching component of `self`.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}
98
99pub struct CostModel {
101 cpu_tuple_cost: f64,
103 #[allow(dead_code)]
105 io_page_cost: f64,
106 hash_lookup_cost: f64,
108 sort_comparison_cost: f64,
110 avg_tuple_size: f64,
112 page_size: f64,
114}
115
116impl CostModel {
117 #[must_use]
119 pub fn new() -> Self {
120 Self {
121 cpu_tuple_cost: 0.01,
122 io_page_cost: 1.0,
123 hash_lookup_cost: 0.02,
124 sort_comparison_cost: 0.02,
125 avg_tuple_size: 100.0,
126 page_size: 8192.0,
127 }
128 }
129
130 #[must_use]
132 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
133 match op {
134 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
135 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
136 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
137 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
138 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
139 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
140 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
141 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
142 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
143 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
144 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
145 LogicalOperator::Empty => Cost::zero(),
146 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
147 }
148 }
149
150 fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
152 let pages = (cardinality * self.avg_tuple_size) / self.page_size;
153 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
154 }
155
156 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
158 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
160 }
161
162 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
164 let expr_count = project.projections.len() as f64;
166 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
167 }
168
169 fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
171 let lookup_cost = cardinality * self.hash_lookup_cost;
173 let avg_fanout = 10.0;
175 let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
176 Cost::cpu(lookup_cost + output_cost)
177 }
178
179 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
181 match join.join_type {
183 JoinType::Cross => {
184 Cost::cpu(cardinality * self.cpu_tuple_cost)
186 }
187 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
188 let build_cardinality = cardinality.sqrt(); let probe_cardinality = cardinality.sqrt();
192
193 let build_cost = build_cardinality * self.hash_lookup_cost;
195 let memory_cost = build_cardinality * self.avg_tuple_size;
196
197 let probe_cost = probe_cardinality * self.hash_lookup_cost;
199
200 let output_cost = cardinality * self.cpu_tuple_cost;
202
203 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
204 }
205 JoinType::Semi | JoinType::Anti => {
206 let build_cardinality = cardinality.sqrt();
208 let probe_cardinality = cardinality.sqrt();
209
210 let build_cost = build_cardinality * self.hash_lookup_cost;
211 let probe_cost = probe_cardinality * self.hash_lookup_cost;
212
213 Cost::cpu(build_cost + probe_cost)
214 .with_memory(build_cardinality * self.avg_tuple_size)
215 }
216 }
217 }
218
219 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
221 let hash_cost = cardinality * self.hash_lookup_cost;
223
224 let agg_count = agg.aggregates.len() as f64;
226 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
227
228 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
231
232 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
233 }
234
235 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
237 if cardinality <= 1.0 {
238 return Cost::zero();
239 }
240
241 let comparisons = cardinality * cardinality.log2();
243 let key_count = sort.keys.len() as f64;
244
245 let memory_cost = cardinality * self.avg_tuple_size;
247
248 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
249 }
250
251 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
253 let hash_cost = cardinality * self.hash_lookup_cost;
255 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
258 }
259
260 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
262 Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
264 }
265
266 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
268 Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
270 }
271
272 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
274 let expr_count = ret.items.len() as f64;
276 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
277 }
278
279 #[must_use]
281 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
282 if a.total() <= b.total() { a } else { b }
283 }
284}
285
286impl Default for CostModel {
287 fn default() -> Self {
288 Self::new()
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortKey, SortOrder,
    };

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 0.001;

    /// Builds a variable reference expression.
    fn var(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// Builds the placeholder input used by operators under test.
    fn empty_input() -> Box<LogicalOperator> {
        Box::new(LogicalOperator::Empty)
    }

    /// Builds a join of the given type over two empty inputs.
    fn join_of(join_type: JoinType, conditions: Vec<JoinCondition>) -> JoinOp {
        JoinOp {
            left: empty_input(),
            right: empty_input(),
            join_type,
            conditions,
        }
    }

    /// Builds a sort key on a variable.
    fn sort_key(name: &str, order: SortOrder) -> SortKey {
        SortKey {
            expression: var(name),
            order,
        }
    }

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < EPS);
        assert!((c.io - 5.0).abs() < EPS);
        assert!((c.memory - 100.0).abs() < EPS);
        // Neither operand set a network cost, so the sum must not have one.
        assert!(c.network.abs() < EPS);
    }

    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        // 10 + 1 * 10 + 100 * 0.1
        assert!((cost.total() - 30.0).abs() < EPS);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        // Sort on one key: with an empty key list the CPU component is zero
        // and growth would only be visible through memory.
        let sort = SortOp {
            keys: vec![sort_key("a", SortOrder::Ascending)],
            input: empty_input(),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        assert!(cost_100.cpu > 0.0);
        assert!(cost_1000.cpu > cost_100.cpu);
        assert!(cost_1000.total() > cost_100.total());

        // Inputs of at most one row need no sorting at all.
        assert!(model.sort_cost(&sort, 1.0).total().abs() < EPS);
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!(cost.cpu.abs() < EPS);
        assert!(cost.io.abs() < EPS);
        assert!(cost.memory.abs() < EPS);
        assert!(cost.network.abs() < EPS);
        assert!(cost.total().abs() < EPS);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < EPS);
        assert!((cost.io - 2.0).abs() < EPS);
    }

    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        // 10 * 2 + 2 * 5 + 100 * 0.5
        assert!((total - 80.0).abs() < EPS);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: empty_input(),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        assert!(cost.cpu > 0.0);
        // Filtering is pure CPU work.
        assert!(cost.io.abs() < EPS);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let project = ProjectOp {
            projections: vec![
                Projection {
                    expression: var("a"),
                    alias: None,
                },
                Projection {
                    expression: var("b"),
                    alias: None,
                },
            ],
            input: empty_input(),
        };
        let cost = model.project_cost(&project, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: empty_input(),
            path_alias: None,
        };
        let cost = model.expand_cost(&expand, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = join_of(
            JoinType::Inner,
            vec![JoinCondition {
                left: var("a"),
                right: var("b"),
            }],
        );
        let cost = model.join_cost(&join, 10000.0);

        assert!(cost.cpu > 0.0);
        // A hash join has to hold its build side in memory.
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = join_of(JoinType::Cross, vec![]);
        let cost = model.join_cost(&join, 1000000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();

        let semi_join = join_of(JoinType::Semi, vec![]);
        let cost_semi = model.join_cost(&semi_join, 1000.0);

        let inner_join = join_of(JoinType::Inner, vec![]);
        let cost_inner = model.join_cost(&inner_join, 1000.0);

        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
        // A semi join emits no joined rows, so it must cost less CPU than an
        // inner join over the same cardinality.
        assert!(cost_semi.cpu < cost_inner.cpu);
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(var("x")),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: empty_input(),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: empty_input(),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: empty_input(),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        assert!(cost.cpu > 0.0);
        // Limiting to ten rows must stay cheap regardless of input size.
        assert!(cost.cpu < 1.0);
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: empty_input(),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: var("a"),
                    alias: None,
                },
                ReturnItem {
                    expression: var("b"),
                    alias: None,
                },
            ],
            distinct: false,
            input: empty_input(),
        };
        let cost = model.return_cost(&ret, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        // The result must not depend on argument order.
        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // total = 100 + 1 * 10 = 110
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // total = 10 + 20 * 10 = 210
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![sort_key("a", SortOrder::Ascending)],
            input: empty_input(),
        };
        let sort_multi = SortOp {
            keys: vec![
                sort_key("a", SortOrder::Ascending),
                sort_key("b", SortOrder::Descending),
            ],
            input: empty_input(),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        // Every additional key adds comparison work.
        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!(cost.total().abs() < EPS);
    }

    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }
}