#![cfg(feature = "inventory")]
mod common;
use crate::common::{EventLoggerDatabase, LogDatabase};
use expect_test::expect;
use salsa::{Database, Setter};
/// A graph of tracked `Node`s. Edges refer to their target by index into `nodes`.
#[derive(Clone, Debug, Eq, PartialEq, Hash, salsa::Update)]
struct Graph<'db> {
/// The graph's nodes; `Edge::to` is an index into this vector.
nodes: Vec<Node<'db>>,
}
impl<'db> Graph<'db> {
    /// Returns the first node whose name equals `name`, or `None` if no node matches.
    ///
    /// Nodes are checked in the order they appear in `nodes`.
    fn find_node(&self, db: &dyn salsa::Database, name: &str) -> Option<Node<'db>> {
        let has_wanted_name = |candidate: &Node<'db>| candidate.name(db).as_str() == name;
        self.nodes.iter().copied().find(has_wanted_name)
    }
}
/// A directed, weighted edge between two nodes of the same `Graph`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Edge {
    /// Index of the target node in the owning graph's `nodes` vector.
    to: usize,
    /// Cost of traversing this edge.
    cost: usize,
}
/// A tracked graph node. It stores the `GraphInput` it was built from so that
/// queries such as `cost_to_start` can look up the whole graph via `create_graph`.
#[salsa::tracked(debug)]
struct Node<'db> {
/// Node label; matched by `Graph::find_node`.
#[returns(ref)]
name: String,
/// Outgoing edges; each `Edge::to` indexes into the owning graph's `nodes`.
#[returns(deref)]
#[tracked]
edges: Vec<Edge>,
/// The input whose `create_graph` result contains this node.
graph: GraphInput,
}
/// Salsa input that selects which graph `create_graph` builds and how
/// `create_tracked_in_cycle` behaves.
#[salsa::input(debug)]
struct GraphInput {
/// `true`: three-node acyclic graph; `false`: four-node graph with a `b <-> d` cycle.
simple: bool,
/// 0: `create_tracked_in_cycle` returns a single "base" node with no cycle;
/// 1: the cycle settles at 3 nodes; any other value: it settles at 2 nodes.
fixpoint_variant: usize,
}
/// Builds one of two fixed graphs for `input`.
///
/// * `input.simple == true`: `c -> b` (cost 2), `b -> a` (cost 20).
/// * otherwise: `b -> d` (20), `c -> d` (4), `d -> a` (4), `d -> b` (4), so
///   `b` and `d` form a cycle.
///
/// `Edge::to` is the index of the target node in the returned `nodes` vector.
/// Keep the node creation order stable: the expected event log in `main`
/// refers to nodes by id (presumably assigned in creation order).
#[salsa::tracked(returns(ref))]
fn create_graph(db: &dyn salsa::Database, input: GraphInput) -> Graph<'_> {
if input.simple(db) {
// Acyclic: c -> b -> a.
let a = Node::new(db, "a".to_string(), vec![], input);
let b = Node::new(db, "b".to_string(), vec![Edge { to: 0, cost: 20 }], input);
let c = Node::new(db, "c".to_string(), vec![Edge { to: 1, cost: 2 }], input);
Graph {
nodes: vec![a, b, c],
}
} else {
// Cyclic: b -> d -> b, with c -> d and d -> a leading to the start node.
let a = Node::new(db, "a".to_string(), vec![], input);
let b = Node::new(db, "b".to_string(), vec![Edge { to: 3, cost: 20 }], input);
let c = Node::new(db, "c".to_string(), vec![Edge { to: 3, cost: 4 }], input);
let d = Node::new(
db,
"d".to_string(),
vec![Edge { to: 0, cost: 4 }, Edge { to: 1, cost: 4 }],
input,
);
Graph {
nodes: vec![a, b, c, d],
}
}
}
/// Cheapest total edge cost from `node` to the node at index 0 of its graph,
/// or `usize::MAX` if no path to index 0 was found.
///
/// Queries that depend on each other cyclically (e.g. `b` and `d` in the
/// non-simple graph) are resolved by salsa's fixpoint iteration, seeded with
/// `max_initial`. The start node itself is not special-cased: a node with no
/// outgoing edges yields `usize::MAX`.
#[salsa::tracked(cycle_initial=max_initial)]
fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize {
let mut min_cost = usize::MAX;
let graph = create_graph(db, node.graph(db));
for edge in node.edges(db) {
// A direct edge to the start node costs just the edge itself.
if edge.to == 0 {
min_cost = min_cost.min(edge.cost);
}
// The recursive query also runs for `edge.to == 0`; the expected logs depend on it.
let edge_cost_to_start = cost_to_start(db, graph.nodes[edge.to]);
// Skip targets with no known path; this also avoids computing `edge.cost + usize::MAX`.
if edge_cost_to_start == usize::MAX {
continue;
}
// NOTE(review): this sum can still overflow if finite costs are near usize::MAX.
min_cost = min_cost.min(edge.cost + edge_cost_to_start);
}
min_cost
}
/// Cycle seed for `cost_to_start`: every query in the cycle starts out as
/// "no path found" (`usize::MAX`) before fixpoint iteration refines it.
fn max_initial(_database: &dyn Database, _query_id: salsa::Id, _seed_node: Node) -> usize {
    usize::MAX
}
/// Computes `cost_to_start` for `c` on the cyclic graph (8, via `c -> d -> a`).
/// It then switches the input to the simple graph (22, via `c -> b -> a`) and
/// checks the full event log. The log covers the cycle iteration on the first
/// run and, after `create_graph` re-executes, the discarding of `Node(Id(403))`
/// (presumably `d`, which the simple graph no longer creates).
#[test]
fn main() {
let mut db = EventLoggerDatabase::default();
// Start with the four-node graph that contains the `b <-> d` cycle.
let input = GraphInput::new(&db, false, 0);
let graph = create_graph(&db, input);
let c = graph.find_node(&db, "c").unwrap();
assert_eq!(cost_to_start(&db, c), 8);
// Switch to the simple three-node graph; this invalidates `create_graph`.
input.set_simple(&mut db).to(true);
let graph = create_graph(&db, input);
let c = graph.find_node(&db, "c").unwrap();
assert_eq!(cost_to_start(&db, c), 22);
db.assert_logs(expect![[r#"
[
"WillCheckCancellation",
"WillExecute { database_key: create_graph(Id(0)) }",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(402)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(403)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(400)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(401)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillIterateCycle { database_key: cost_to_start(Id(403)), iteration_count: IterationCount(1) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(401)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"DidFinalizeCycle { database_key: cost_to_start(Id(403)), iteration_count: IterationCount(1) }",
"DidSetCancellationFlag",
"WillCheckCancellation",
"WillExecute { database_key: create_graph(Id(0)) }",
"WillDiscardStaleOutput { execute_key: create_graph(Id(0)), output_key: Node(Id(403)) }",
"DidDiscard { key: Node(Id(403)) }",
"DidDiscard { key: cost_to_start(Id(403)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(402)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(401)) }",
"WillCheckCancellation",
"WillCheckCancellation",
"WillCheckCancellation",
"WillExecute { database_key: cost_to_start(Id(400)) }",
"WillCheckCancellation",
]"#]]);
}
/// Tracked struct created by `create_tracked_in_cycle`, including from inside
/// its cycle iterations.
#[salsa::tracked]
struct IterationNode<'db> {
/// Label: `"base"` (no cycle) or `"iter_N"`.
#[returns(ref)]
name: String,
/// Index the node occupies in the result vector it was appended to.
iteration: usize,
}
/// Creates tracked structs from inside a self-referential query.
///
/// * `fixpoint_variant == 0`: no cycle; returns a single `"base"` node.
/// * otherwise the query calls itself. Starting from the empty seed
///   (`initial_with_structs`), each iteration appends one `"iter_N"` node.
///   Appending stops once the result holds 3 nodes (variant 1) or 2 nodes
///   (any other variant); after that the value stops changing.
#[salsa::tracked(cycle_initial=initial_with_structs)]
fn create_tracked_in_cycle<'db>(
db: &'db dyn Database,
input: GraphInput,
) -> Vec<IterationNode<'db>> {
let variant = input.fixpoint_variant(db);
if variant == 0 {
vec![IterationNode::new(db, "base".to_string(), 0)]
} else {
// Self-call: inside the cycle this yields the current iteration's provisional value.
let previous = create_tracked_in_cycle(db, input);
// The cycle seed is empty, so this branch produces the first node.
if previous.is_empty() {
vec![IterationNode::new(db, "iter_0".to_string(), 0)]
} else {
// Maximum number of nodes before the result stops growing.
let limit = if variant == 1 { 3 } else { 2 };
if previous.len() < limit {
let mut nodes = previous;
nodes.push(IterationNode::new(
db,
format!("iter_{}", nodes.len()),
nodes.len(),
));
nodes
} else {
previous
}
}
}
}
/// Cycle seed for `create_tracked_in_cycle`: the first iteration observes an
/// empty list of nodes.
fn initial_with_structs(
    _db: &dyn Database,
    _id: salsa::Id,
    _input: GraphInput,
) -> Vec<IterationNode<'_>> {
    Vec::new()
}
/// Runs the cyclic query with variant 1 (settles at three nodes). It then
/// switches to variant 2 (settles at two nodes) and checks that the re-run
/// iterates the cycle and discards the stale `IterationNode(Id(402))` output
/// that is no longer produced.
#[test_log::test]
fn test_cycle_with_fixpoint_structs() {
let mut db = EventLoggerDatabase::default();
let input = GraphInput::new(&db, false, 1);
let nodes = create_tracked_in_cycle(&db, input);
assert_eq!(nodes.len(), 3);
assert_eq!(nodes[0].name(&db), "iter_0");
assert_eq!(nodes[1].name(&db), "iter_1");
assert_eq!(nodes[2].name(&db), "iter_2");
// Only the events of the second run are asserted below.
db.clear_logs();
input.set_fixpoint_variant(&mut db).to(2);
let nodes = create_tracked_in_cycle(&db, input);
assert_eq!(nodes.len(), 2);
assert_eq!(nodes[0].name(&db), "iter_0");
assert_eq!(nodes[1].name(&db), "iter_1");
db.assert_logs(expect![[r#"
[
"DidSetCancellationFlag",
"WillCheckCancellation",
"WillExecute { database_key: create_tracked_in_cycle(Id(0)) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(1) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(2) }",
"WillCheckCancellation",
"DidFinalizeCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(2) }",
"WillDiscardStaleOutput { execute_key: create_tracked_in_cycle(Id(0)), output_key: IterationNode(Id(402)) }",
"DidDiscard { key: IterationNode(Id(402)) }",
]"#]]);
}
/// Tracked struct with a `#[tracked]` field; it is created inside the
/// `query_a`/`query_b` cycle in `cycle_tracked_struct_with_tracked_field`.
#[salsa::tracked(debug)]
struct NameWithOffset<'db> {
name: String,
/// Value computed by `query_b`, read back by `query_a` through this tracked field.
#[tracked]
offset: u32,
}
/// `query_a` and `query_b` form a cycle. On every iteration `query_a` creates
/// a `NameWithOffset` and returns its tracked field. Starting from the seed 0,
/// the value grows by one per iteration until `query_b` clamps it to 5, where
/// it stops changing.
#[test]
fn cycle_tracked_struct_with_tracked_field() {
// `query_a` supplies the cycle's initial value (0).
#[salsa::tracked(cycle_initial=|_,_| 0)]
fn query_a(db: &dyn salsa::Database) -> u32 {
let offset = query_b(db);
let tracked = NameWithOffset::new(db, "test".to_string(), offset);
tracked.offset(db)
}
// Depends back on `query_a`, closing the cycle; the `min(5)` bounds the iteration.
#[salsa::tracked]
fn query_b(db: &dyn salsa::Database) -> u32 {
let base_offset = query_a(db);
(base_offset + 1).min(5)
}
let db = salsa::DatabaseImpl::default();
let result = query_a(&db);
assert_eq!(result, 5);
}