use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};
use crate::core::Expression;
use crate::engine::{ComputeEngine, ComputeError};
/// Evaluation state of a [`LazyExpression`].
///
/// An expression starts as `Pending`. `LazyExpression::force_compute` moves it to
/// `Computing` while it runs, and then to `Computed` or `Failed`. `LazyExpression::reset`
/// puts it back to `Pending`.
#[derive(Debug, Clone, PartialEq)]
pub enum LazyState {
    /// Not evaluated yet (initial state, and the state restored by `reset`).
    Pending,
    /// Evaluation is in progress; forcing the expression again in this state is
    /// reported as an error (re-entrant / cyclic evaluation).
    Computing,
    /// Evaluation succeeded with the contained result.
    Computed(Expression),
    /// Evaluation of this expression, or of one of its dependencies, failed with the contained error.
    Failed(ComputeError),
}
/// Evaluation callback stored by a [`LazyExpression`]: receives the original expression and
/// the engine, and returns the computed expression or an error.
///
/// Named so the field type stays readable (the inline form trips Clippy `type_complexity`).
type ComputeFn =
    dyn Fn(&Expression, &dyn ComputeEngine) -> Result<Expression, ComputeError> + Send + Sync;

/// An expression whose evaluation is deferred until `LazyExpression::force_compute` is called.
///
/// `Clone` is shallow: `state` lives behind an `Arc`, so every clone shares the same state.
/// Evaluating one clone makes the result visible through all of them.
#[derive(Clone)]
pub struct LazyExpression {
    /// Caller-supplied identifier.
    id: usize,
    /// The expression to evaluate.
    original: Expression,
    /// Current evaluation state, shared between clones and guarded by a lock.
    state: Arc<RwLock<LazyState>>,
    /// Expressions that are forced before this one is evaluated.
    dependencies: Vec<Arc<LazyExpression>>,
    /// Custom evaluation; when `None`, the engine's `simplify` is used instead.
    compute_fn: Option<Arc<ComputeFn>>,
    /// Scheduling priority (0 unless changed with `set_priority`).
    priority: i32,
}
impl std::fmt::Debug for LazyExpression {
    /// Formats every field. The callback is reported only as a "present / absent" flag,
    /// because `dyn Fn` values have no `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_compute_fn = self.compute_fn.is_some();
        let mut builder = f.debug_struct("LazyExpression");
        builder.field("id", &self.id);
        builder.field("original", &self.original);
        builder.field("state", &self.state);
        builder.field("dependencies", &self.dependencies);
        builder.field("compute_fn", &has_compute_fn);
        builder.field("priority", &self.priority);
        builder.finish()
    }
}
impl LazyExpression {
    /// Creates a `Pending` lazy expression with no dependencies and priority 0.
    ///
    /// When forced, it is evaluated with the engine's `simplify`.
    pub fn new(id: usize, expr: Expression) -> Self {
        Self {
            id,
            original: expr,
            state: Arc::new(RwLock::new(LazyState::Pending)),
            dependencies: Vec::new(),
            compute_fn: None,
            priority: 0,
        }
    }

    /// Like [`LazyExpression::new`], but evaluates with `compute_fn` instead of `simplify`.
    pub fn with_compute_fn<F>(id: usize, expr: Expression, compute_fn: F) -> Self
    where
        F: Fn(&Expression, &dyn ComputeEngine) -> Result<Expression, ComputeError>
            + Send
            + Sync
            + 'static,
    {
        Self {
            compute_fn: Some(Arc::new(compute_fn)),
            ..Self::new(id, expr)
        }
    }

    /// Appends `dep` to the expressions forced before this one.
    pub fn add_dependency(&mut self, dep: Arc<LazyExpression>) {
        self.dependencies.push(dep);
    }

    /// Sets the scheduling priority.
    pub fn set_priority(&mut self, priority: i32) {
        self.priority = priority;
    }

    /// Returns the caller-supplied identifier.
    pub fn id(&self) -> usize {
        self.id
    }

    /// Returns the unevaluated expression.
    pub fn original(&self) -> &Expression {
        &self.original
    }

    /// Returns a snapshot (clone) of the current state.
    ///
    /// # Panics
    /// Panics if the state lock is poisoned.
    pub fn state(&self) -> LazyState {
        self.state.read().unwrap().clone()
    }

    /// Returns the expressions forced before this one.
    pub fn dependencies(&self) -> &[Arc<LazyExpression>] {
        &self.dependencies
    }

    /// Returns the scheduling priority.
    pub fn priority(&self) -> i32 {
        self.priority
    }

    /// `true` if evaluation has succeeded.
    pub fn is_computed(&self) -> bool {
        // Inspect the state in place instead of cloning the contained expression.
        matches!(*self.state.read().unwrap(), LazyState::Computed(_))
    }

    /// `true` if evaluation has failed.
    pub fn is_failed(&self) -> bool {
        matches!(*self.state.read().unwrap(), LazyState::Failed(_))
    }

    /// `true` while evaluation is in progress.
    pub fn is_computing(&self) -> bool {
        matches!(*self.state.read().unwrap(), LazyState::Computing)
    }

    /// `true` if this expression is still `Pending` and every dependency is already computed.
    pub fn can_compute(&self) -> bool {
        // Bind first so the read guard is released before dependencies are inspected;
        // a self-dependency would otherwise re-acquire this lock while it is held.
        let pending = matches!(*self.state.read().unwrap(), LazyState::Pending);
        pending && self.dependencies.iter().all(|dep| dep.is_computed())
    }

    /// Returns a clone of the result if evaluation has succeeded.
    pub fn get_result(&self) -> Option<Expression> {
        let state = self.state.read().unwrap();
        match &*state {
            LazyState::Computed(expr) => Some(expr.clone()),
            _ => None,
        }
    }

    /// Returns a clone of the error if evaluation has failed.
    pub fn get_error(&self) -> Option<ComputeError> {
        let state = self.state.read().unwrap();
        match &*state {
            LazyState::Failed(err) => Some(err.clone()),
            _ => None,
        }
    }

    /// Evaluates the expression (at most once) and caches the outcome.
    ///
    /// Behavior by current state:
    /// - `Computed` / `Failed`: returns the cached outcome immediately.
    /// - `Computing`: returns an error. This covers cyclic dependencies and a concurrent
    ///   evaluation by another thread.
    /// - `Pending`: marks the expression `Computing` and forces every dependency in order.
    ///   If a dependency fails, this expression is stored as `Failed` with that error.
    ///   Otherwise it evaluates `compute_fn` (or `engine.simplify` when none is set) and
    ///   stores the outcome.
    ///
    /// Dependency results are not passed to the evaluation: the callback receives only
    /// `original` and the engine.
    ///
    /// If the evaluation panics, the state is left as `Computing`.
    ///
    /// # Errors
    /// Returns the dependency's or the evaluation's `ComputeError`, or
    /// `ComputeError::UnsupportedOperation` when re-entered while `Computing`.
    ///
    /// # Panics
    /// Panics if the state lock is poisoned.
    pub fn force_compute(&self, engine: &dyn ComputeEngine) -> Result<Expression, ComputeError> {
        // Inspect and claim the state under a single write guard. This way two threads
        // cannot both observe `Pending` and evaluate twice, and a result stored between
        // a separate read and write is never overwritten. The guard is dropped before
        // recursing, so a cycle reaches the `Computing` error instead of deadlocking.
        {
            let mut state = self.state.write().unwrap();
            match &*state {
                LazyState::Computed(expr) => return Ok(expr.clone()),
                LazyState::Failed(err) => return Err(err.clone()),
                LazyState::Computing => {
                    return Err(ComputeError::UnsupportedOperation {
                        operation: "循环依赖或重复计算".to_string(),
                    });
                }
                LazyState::Pending => {}
            }
            *state = LazyState::Computing;
        }

        for dep in &self.dependencies {
            if let Err(err) = dep.force_compute(engine) {
                self.set_state(LazyState::Failed(err.clone()));
                return Err(err);
            }
        }

        let result = match &self.compute_fn {
            Some(compute_fn) => compute_fn(&self.original, engine),
            None => engine.simplify(&self.original),
        };
        self.set_state(match &result {
            Ok(expr) => LazyState::Computed(expr.clone()),
            Err(err) => LazyState::Failed(err.clone()),
        });
        result
    }

    /// Returns the expression to `Pending`, discarding any cached result or error.
    pub fn reset(&self) {
        self.set_state(LazyState::Pending);
    }

    /// Replaces the shared state under the write lock.
    fn set_state(&self, next: LazyState) {
        *self.state.write().unwrap() = next;
    }
}
/// Registry of lazy expressions and the dependency edges between them.
///
/// Every edge is recorded in both directions (`dependencies` and `dependents`), so the
/// ordering algorithms can walk from an expression to what it needs and to what needs it.
///
/// NOTE(review): edges are tracked only in these maps. They are not mirrored into each
/// `LazyExpression`'s own dependency list, which `force_compute` uses.
#[derive(Debug)]
pub struct DependencyGraph {
    // All registered expressions, keyed by the id assigned at registration.
    expressions: HashMap<usize, Arc<LazyExpression>>,
    // expr_id -> ids that must be computed before expr_id.
    dependencies: HashMap<usize, HashSet<usize>>,
    // dep_id -> ids that depend on dep_id (reverse of `dependencies`).
    dependents: HashMap<usize, HashSet<usize>>,
    // Id handed out to the next registered expression.
    next_id: usize,
}
impl DependencyGraph {
pub fn new() -> Self {
Self {
expressions: HashMap::new(),
dependencies: HashMap::new(),
dependents: HashMap::new(),
next_id: 1,
}
}
pub fn add_expression(&mut self, expr: Expression) -> Arc<LazyExpression> {
let id = self.next_id;
self.next_id += 1;
let lazy_expr = Arc::new(LazyExpression::new(id, expr));
self.expressions.insert(id, lazy_expr.clone());
self.dependencies.insert(id, HashSet::new());
self.dependents.insert(id, HashSet::new());
lazy_expr
}
pub fn add_expression_with_fn<F>(&mut self, expr: Expression, compute_fn: F) -> Arc<LazyExpression>
where
F: Fn(&Expression, &dyn ComputeEngine) -> Result<Expression, ComputeError> + Send + Sync + 'static,
{
let id = self.next_id;
self.next_id += 1;
let lazy_expr = Arc::new(LazyExpression::with_compute_fn(id, expr, compute_fn));
self.expressions.insert(id, lazy_expr.clone());
self.dependencies.insert(id, HashSet::new());
self.dependents.insert(id, HashSet::new());
lazy_expr
}
pub fn add_dependency(&mut self, expr_id: usize, dep_id: usize) -> Result<(), ComputeError> {
if self.would_create_cycle(expr_id, dep_id) {
return Err(ComputeError::UnsupportedOperation {
operation: "添加依赖会形成循环".to_string(),
});
}
self.dependencies.entry(expr_id).or_default().insert(dep_id);
self.dependents.entry(dep_id).or_default().insert(expr_id);
if let Some(_expr) = self.expressions.get(&expr_id) {
if let Some(_dep_expr) = self.expressions.get(&dep_id) {
}
}
Ok(())
}
fn would_create_cycle(&self, from: usize, to: usize) -> bool {
if from == to {
return true;
}
let mut visited = HashSet::new();
let mut stack = vec![to];
while let Some(current) = stack.pop() {
if visited.contains(¤t) {
continue;
}
visited.insert(current);
if current == from {
return true;
}
if let Some(deps) = self.dependencies.get(¤t) {
for &dep in deps {
if !visited.contains(&dep) {
stack.push(dep);
}
}
}
}
false
}
pub fn topological_sort(&self) -> Result<Vec<usize>, ComputeError> {
let mut in_degree = HashMap::new();
let mut result = Vec::new();
let mut queue = Vec::new();
for &id in self.expressions.keys() {
in_degree.insert(id, 0);
}
for (&expr_id, deps) in &self.dependencies {
let current_degree = in_degree.get(&expr_id).unwrap_or(&0);
in_degree.insert(expr_id, current_degree + deps.len());
}
for (&id, °ree) in &in_degree {
if degree == 0 {
queue.push(id);
}
}
while let Some(current) = queue.pop() {
result.push(current);
if let Some(dependents) = self.dependents.get(¤t) {
for &dependent in dependents {
if let Some(degree) = in_degree.get_mut(&dependent) {
*degree -= 1;
if *degree == 0 {
queue.push(dependent);
}
}
}
}
}
if result.len() != self.expressions.len() {
return Err(ComputeError::UnsupportedOperation {
operation: "检测到循环依赖".to_string(),
});
}
Ok(result)
}
pub fn get_parallel_groups(&self) -> Result<Vec<Vec<usize>>, ComputeError> {
let sorted = self.topological_sort()?;
let mut groups = Vec::new();
let mut computed = HashSet::new();
let mut remaining: HashSet<usize> = sorted.iter().cloned().collect();
while !remaining.is_empty() {
let mut current_group = Vec::new();
let ready_to_compute: Vec<usize> = remaining.iter()
.filter(|&&expr_id| {
if let Some(deps) = self.dependencies.get(&expr_id) {
deps.iter().all(|dep| computed.contains(dep))
} else {
true
}
})
.cloned()
.collect();
for expr_id in ready_to_compute {
current_group.push(expr_id);
computed.insert(expr_id);
remaining.remove(&expr_id);
}
if !current_group.is_empty() {
groups.push(current_group);
} else {
return Err(ComputeError::UnsupportedOperation {
operation: "检测到循环依赖或无法解析的依赖关系".to_string(),
});
}
}
Ok(groups)
}
pub fn get_expression(&self, id: usize) -> Option<Arc<LazyExpression>> {
self.expressions.get(&id).cloned()
}
pub fn get_all_expressions(&self) -> Vec<Arc<LazyExpression>> {
self.expressions.values().cloned().collect()
}
pub fn cleanup_completed(&mut self) {
let completed_ids: Vec<usize> = self.expressions
.iter()
.filter(|(_, expr)| expr.is_computed() || expr.is_failed())
.map(|(&id, _)| id)
.collect();
for id in completed_ids {
self.expressions.remove(&id);
self.dependencies.remove(&id);
self.dependents.remove(&id);
for deps in self.dependencies.values_mut() {
deps.remove(&id);
}
for dependents in self.dependents.values_mut() {
dependents.remove(&id);
}
}
}
pub fn reset_all(&self) {
for expr in self.expressions.values() {
expr.reset();
}
}
pub fn get_stats(&self) -> DependencyGraphStats {
let total = self.expressions.len();
let mut pending = 0;
let mut computing = 0;
let mut computed = 0;
let mut failed = 0;
for expr in self.expressions.values() {
match expr.state() {
LazyState::Pending => pending += 1,
LazyState::Computing => computing += 1,
LazyState::Computed(_) => computed += 1,
LazyState::Failed(_) => failed += 1,
}
}
DependencyGraphStats {
total_expressions: total,
pending_expressions: pending,
computing_expressions: computing,
computed_expressions: computed,
failed_expressions: failed,
total_dependencies: self.dependencies.values().map(|deps| deps.len()).sum(),
}
}
}
/// Snapshot of a `DependencyGraph`'s size and per-state expression counts, as built by
/// `DependencyGraph::get_stats`.
#[derive(Debug, Clone)]
pub struct DependencyGraphStats {
    /// Number of registered expressions.
    pub total_expressions: usize,
    /// Expressions in `LazyState::Pending`.
    pub pending_expressions: usize,
    /// Expressions in `LazyState::Computing`.
    pub computing_expressions: usize,
    /// Expressions in `LazyState::Computed`.
    pub computed_expressions: usize,
    /// Expressions in `LazyState::Failed`.
    pub failed_expressions: usize,
    /// Total number of dependency edges (sum of every expression's direct dependency count).
    pub total_dependencies: usize,
}
impl Default for DependencyGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Expression;
    use crate::engine::compute::BasicComputeEngine;

    /// Index of `id` in `order`; panics if `id` is missing.
    fn index_of(order: &[usize], id: usize) -> usize {
        order.iter().position(|&candidate| candidate == id).unwrap()
    }

    #[test]
    fn test_lazy_expression_creation() {
        let x = Expression::variable("x");
        let lazy = LazyExpression::new(1, x.clone());

        assert_eq!(lazy.id(), 1);
        assert_eq!(lazy.original(), &x);
        // A fresh expression has not been evaluated in any way.
        assert!(matches!(lazy.state(), LazyState::Pending));
        assert!(!lazy.is_computed());
        assert!(!lazy.is_failed());
        assert!(!lazy.is_computing());
    }

    #[test]
    fn test_dependency_graph() {
        let mut graph = DependencyGraph::new();
        let x = graph.add_expression(Expression::variable("x"));
        let y = graph.add_expression(Expression::variable("y"));
        let sum = graph.add_expression(Expression::add(
            Expression::variable("x"),
            Expression::variable("y"),
        ));

        // `sum` needs both operands first.
        graph.add_dependency(sum.id(), x.id()).unwrap();
        graph.add_dependency(sum.id(), y.id()).unwrap();

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);
        let x_at = index_of(&order, x.id());
        let y_at = index_of(&order, y.id());
        let sum_at = index_of(&order, sum.id());
        assert!(x_at < sum_at);
        assert!(y_at < sum_at);
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = DependencyGraph::new();
        let x = graph.add_expression(Expression::variable("x"));
        let y = graph.add_expression(Expression::variable("y"));
        let z = graph.add_expression(Expression::variable("z"));
        let sum = graph.add_expression(Expression::add(
            Expression::variable("x"),
            Expression::variable("y"),
        ));
        let product = graph.add_expression(Expression::multiply(
            Expression::add(Expression::variable("x"), Expression::variable("y")),
            Expression::variable("z"),
        ));

        graph.add_dependency(sum.id(), x.id()).unwrap();
        graph.add_dependency(sum.id(), y.id()).unwrap();
        graph.add_dependency(product.id(), sum.id()).unwrap();
        graph.add_dependency(product.id(), z.id()).unwrap();

        let groups = graph.get_parallel_groups().unwrap();
        // Level 0: x, y, z; level 1: x + y; level 2: (x + y) * z.
        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].len(), 3);
        assert_eq!(groups[1].len(), 1);
        assert_eq!(groups[2].len(), 1);
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = DependencyGraph::new();
        let first = graph.add_expression(Expression::variable("x"));
        let second = graph.add_expression(Expression::variable("y"));

        graph.add_dependency(second.id(), first.id()).unwrap();
        // The reverse edge would close the loop first -> second -> first.
        let outcome = graph.add_dependency(first.id(), second.id());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_force_compute() {
        let engine = BasicComputeEngine::new();
        let two_plus_three = Expression::add(
            Expression::number(2.into()),
            Expression::number(3.into()),
        );
        let lazy = LazyExpression::new(1, two_plus_three);

        let value = lazy.force_compute(&engine).unwrap();
        assert!(lazy.is_computed());
        assert_eq!(lazy.get_result(), Some(value));
    }
}