1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7 NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8};
9
/// Estimated cost of running an operator, broken down by resource.
///
/// Every component is a plain `f64` magnitude. Use [`Cost::total`] to fold the
/// components into one comparable number.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// CPU component of the estimate.
    pub cpu: f64,
    /// I/O component of the estimate.
    pub io: f64,
    /// Memory component of the estimate.
    pub memory: f64,
    /// Network component of the estimate.
    pub network: f64,
}

impl Cost {
    /// Returns a cost whose four components are all `0.0`.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Returns a cost with the given CPU component and every other component
    /// set to `0.0`.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Returns `self` with the I/O component replaced by `io`.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Returns `self` with the memory component replaced by `memory`.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Returns `self` with the network component replaced by `network`.
    ///
    /// Counterpart of [`Cost::with_io`] and [`Cost::with_memory`], so that the
    /// component carrying the largest weight in [`Cost::total`] can be set in
    /// the same builder chain as the others.
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Collapses the components into a single number using fixed weights:
    /// `cpu + io * 10 + memory * 0.1 + network * 100`.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Collapses the CPU, I/O and memory components using caller-supplied
    /// weights.
    ///
    /// The network component is not part of this sum.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}
76
77impl std::ops::Add for Cost {
78 type Output = Self;
79
80 fn add(self, other: Self) -> Self {
81 Self {
82 cpu: self.cpu + other.cpu,
83 io: self.io + other.io,
84 memory: self.memory + other.memory,
85 network: self.network + other.network,
86 }
87 }
88}
89
90impl std::ops::AddAssign for Cost {
91 fn add_assign(&mut self, other: Self) {
92 self.cpu += other.cpu;
93 self.io += other.io;
94 self.memory += other.memory;
95 self.network += other.network;
96 }
97}
98
/// Constants that drive the per-operator cost formulas.
///
/// All fields are plain `f64` values, so the model is cheap to clone and can
/// be printed with `{:?}` when debugging plan choices.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged per tuple processed.
    cpu_tuple_cost: f64,
    /// Cost of a single page of I/O.
    // Initialised by the constructor but not read by any estimator in this
    // file, hence the lint allowance.
    #[allow(dead_code)]
    io_page_cost: f64,
    /// CPU cost charged per hash lookup.
    hash_lookup_cost: f64,
    /// CPU cost charged per sort comparison.
    sort_comparison_cost: f64,
    /// Average tuple size used for memory and page estimates
    /// (presumably bytes, matching `page_size` — confirm).
    avg_tuple_size: f64,
    /// Page size used to convert scanned data volume into pages.
    page_size: f64,
}
115
116impl CostModel {
117 #[must_use]
119 pub fn new() -> Self {
120 Self {
121 cpu_tuple_cost: 0.01,
122 io_page_cost: 1.0,
123 hash_lookup_cost: 0.02,
124 sort_comparison_cost: 0.02,
125 avg_tuple_size: 100.0,
126 page_size: 8192.0,
127 }
128 }
129
130 #[must_use]
132 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
133 match op {
134 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
135 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
136 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
137 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
138 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
139 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
140 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
141 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
142 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
143 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
144 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
145 LogicalOperator::Empty => Cost::zero(),
146 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
147 }
148 }
149
150 fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
152 let pages = (cardinality * self.avg_tuple_size) / self.page_size;
153 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
154 }
155
156 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
158 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
160 }
161
162 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
164 let expr_count = project.projections.len() as f64;
166 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
167 }
168
169 fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
171 let lookup_cost = cardinality * self.hash_lookup_cost;
173 let avg_fanout = 10.0;
175 let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
176 Cost::cpu(lookup_cost + output_cost)
177 }
178
179 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
181 match join.join_type {
183 JoinType::Cross => {
184 Cost::cpu(cardinality * self.cpu_tuple_cost)
186 }
187 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
188 let build_cardinality = cardinality.sqrt(); let probe_cardinality = cardinality.sqrt();
192
193 let build_cost = build_cardinality * self.hash_lookup_cost;
195 let memory_cost = build_cardinality * self.avg_tuple_size;
196
197 let probe_cost = probe_cardinality * self.hash_lookup_cost;
199
200 let output_cost = cardinality * self.cpu_tuple_cost;
202
203 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
204 }
205 JoinType::Semi | JoinType::Anti => {
206 let build_cardinality = cardinality.sqrt();
208 let probe_cardinality = cardinality.sqrt();
209
210 let build_cost = build_cardinality * self.hash_lookup_cost;
211 let probe_cost = probe_cardinality * self.hash_lookup_cost;
212
213 Cost::cpu(build_cost + probe_cost)
214 .with_memory(build_cardinality * self.avg_tuple_size)
215 }
216 }
217 }
218
219 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
221 let hash_cost = cardinality * self.hash_lookup_cost;
223
224 let agg_count = agg.aggregates.len() as f64;
226 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
227
228 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
231
232 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
233 }
234
235 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
237 if cardinality <= 1.0 {
238 return Cost::zero();
239 }
240
241 let comparisons = cardinality * cardinality.log2();
243 let key_count = sort.keys.len() as f64;
244
245 let memory_cost = cardinality * self.avg_tuple_size;
247
248 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
249 }
250
251 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
253 let hash_cost = cardinality * self.hash_lookup_cost;
255 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
258 }
259
260 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
262 Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
264 }
265
266 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
268 Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
270 }
271
272 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
274 let expr_count = ret.items.len() as f64;
276 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
277 }
278
279 #[must_use]
281 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
282 if a.total() <= b.total() { a } else { b }
283 }
284}
285
286impl Default for CostModel {
287 fn default() -> Self {
288 Self::new()
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortOrder,
    };

    /// Absolute tolerance shared by every approximate float comparison.
    const TOLERANCE: f64 = 0.001;

    /// Asserts that `actual` lies within [`TOLERANCE`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Boxed `Empty` operator used as the input of every operator under test.
    fn leaf() -> Box<LogicalOperator> {
        Box::new(LogicalOperator::Empty)
    }

    /// Variable-reference expression with the given name.
    fn variable(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// Sort key over the variable `name` with the given ordering.
    fn sort_key(name: &str, order: SortOrder) -> crate::query::plan::SortKey {
        crate::query::plan::SortKey {
            expression: variable(name),
            order,
        }
    }

    /// Join of two empty inputs with the given type and conditions.
    fn join_with(join_type: JoinType, conditions: Vec<JoinCondition>) -> JoinOp {
        JoinOp {
            left: leaf(),
            right: leaf(),
            join_type,
            conditions,
        }
    }

    #[test]
    fn test_cost_addition() {
        let sum = Cost::cpu(10.0).with_io(5.0) + Cost::cpu(20.0).with_memory(100.0);

        assert_close(sum.cpu, 30.0);
        assert_close(sum.io, 5.0);
        assert_close(sum.memory, 100.0);
    }

    #[test]
    fn test_cost_total() {
        // 10 + 1 * 10 + 100 * 0.1
        let total = Cost::cpu(10.0).with_io(1.0).with_memory(100.0).total();
        assert_close(total, 30.0);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let scan = NodeScanOp {
            variable: String::from("n"),
            label: Some(String::from("Person")),
            input: None,
        };
        let estimate = CostModel::new().node_scan_cost(&scan, 1000.0);

        assert!(estimate.cpu > 0.0);
        assert!(estimate.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: Vec::new(),
            input: leaf(),
        };

        let small = model.sort_cost(&sort, 100.0);
        let large = model.sort_cost(&sort, 1000.0);

        // More rows must never be cheaper to sort.
        assert!(large.total() > small.total());
    }

    #[test]
    fn test_cost_zero() {
        let zero = Cost::zero();
        for component in [zero.cpu, zero.io, zero.memory, zero.network, zero.total()] {
            assert_close(component, 0.0);
        }
    }

    #[test]
    fn test_cost_add_assign() {
        let mut running = Cost::cpu(10.0);
        running += Cost::cpu(5.0).with_io(2.0);

        assert_close(running.cpu, 15.0);
        assert_close(running.io, 2.0);
    }

    #[test]
    fn test_cost_total_weighted() {
        // 10 * 2 + 2 * 5 + 100 * 0.5
        let weighted = Cost::cpu(10.0)
            .with_io(2.0)
            .with_memory(100.0)
            .total_weighted(2.0, 5.0, 0.5);
        assert_close(weighted, 80.0);
    }

    #[test]
    fn test_cost_model_filter() {
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(graphos_common::types::Value::Bool(true)),
            input: leaf(),
        };
        let estimate = CostModel::new().filter_cost(&filter, 1000.0);

        assert!(estimate.cpu > 0.0);
        assert_close(estimate.io, 0.0);
    }

    #[test]
    fn test_cost_model_project() {
        let projections = ["a", "b"]
            .into_iter()
            .map(|name| Projection {
                expression: variable(name),
                alias: None,
            })
            .collect();
        let project = ProjectOp {
            projections,
            input: leaf(),
        };
        let estimate = CostModel::new().project_cost(&project, 1000.0);

        assert!(estimate.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_expand() {
        let expand = ExpandOp {
            from_variable: String::from("a"),
            to_variable: String::from("b"),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: leaf(),
        };
        let estimate = CostModel::new().expand_cost(&expand, 1000.0);

        assert!(estimate.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let on_a_equals_b = JoinCondition {
            left: variable("a"),
            right: variable("b"),
        };
        let join = join_with(JoinType::Inner, vec![on_a_equals_b]);
        let estimate = CostModel::new().join_cost(&join, 10000.0);

        assert!(estimate.cpu > 0.0);
        assert!(estimate.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let join = join_with(JoinType::Cross, Vec::new());
        let estimate = CostModel::new().join_cost(&join, 1000000.0);

        assert!(estimate.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();

        let semi = model.join_cost(&join_with(JoinType::Semi, Vec::new()), 1000.0);
        let inner = model.join_cost(&join_with(JoinType::Inner, Vec::new()), 1000.0);

        assert!(semi.cpu > 0.0);
        assert!(inner.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_aggregate() {
        let count_all = AggregateExpr {
            function: AggregateFunction::Count,
            expression: None,
            distinct: false,
            alias: Some(String::from("cnt")),
        };
        let sum_x = AggregateExpr {
            function: AggregateFunction::Sum,
            expression: Some(variable("x")),
            distinct: false,
            alias: Some(String::from("total")),
        };
        let agg = AggregateOp {
            group_by: Vec::new(),
            aggregates: vec![count_all, sum_x],
            input: leaf(),
        };
        let estimate = CostModel::new().aggregate_cost(&agg, 1000.0);

        assert!(estimate.cpu > 0.0);
        assert!(estimate.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let distinct = DistinctOp { input: leaf() };
        let estimate = CostModel::new().distinct_cost(&distinct, 1000.0);

        assert!(estimate.cpu > 0.0);
        assert!(estimate.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let limit = LimitOp {
            count: 10,
            input: leaf(),
        };
        let estimate = CostModel::new().limit_cost(&limit, 1000.0);

        // Positive, but far below the cost of touching all 1000 rows.
        assert!(estimate.cpu > 0.0);
        assert!(estimate.cpu < 1.0);
    }

    #[test]
    fn test_cost_model_skip() {
        let skip = SkipOp {
            count: 100,
            input: leaf(),
        };
        let estimate = CostModel::new().skip_cost(&skip, 1000.0);

        assert!(estimate.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let items = ["a", "b"]
            .into_iter()
            .map(|name| ReturnItem {
                expression: variable(name),
                alias: None,
            })
            .collect();
        let ret = ReturnOp {
            items,
            distinct: false,
            input: leaf(),
        };
        let estimate = CostModel::new().return_cost(&ret, 1000.0);

        assert!(estimate.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let low = Cost::cpu(10.0);
        let high = Cost::cpu(100.0);

        // The argument order must not influence the winner.
        assert_eq!(model.cheaper(&low, &high).total(), low.total());
        assert_eq!(model.cheaper(&high, &low).total(), low.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // total = 100 + 1 * 10 = 110
        let mostly_cpu = Cost::cpu(100.0).with_io(1.0);
        // total = 10 + 20 * 10 = 210
        let mostly_io = Cost::cpu(10.0).with_io(20.0);

        assert!(mostly_cpu.total() < mostly_io.total());
        assert_eq!(
            model.cheaper(&mostly_cpu, &mostly_io).total(),
            mostly_cpu.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let one_key = SortOp {
            keys: vec![sort_key("a", SortOrder::Ascending)],
            input: leaf(),
        };
        let two_keys = SortOp {
            keys: vec![
                sort_key("a", SortOrder::Ascending),
                sort_key("b", SortOrder::Descending),
            ],
            input: leaf(),
        };

        let single = model.sort_cost(&one_key, 1000.0);
        let multi = model.sort_cost(&two_keys, 1000.0);

        // Each extra key adds comparison work.
        assert!(multi.cpu > single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let estimate = CostModel::new().estimate(&LogicalOperator::Empty, 0.0);
        assert_close(estimate.total(), 0.0);
    }

    #[test]
    fn test_cost_model_default() {
        let scan = NodeScanOp {
            variable: String::from("n"),
            label: None,
            input: None,
        };
        let estimate = CostModel::default().estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(estimate.total() > 0.0);
    }
}