use crate::asg::{Asg, AsgId, NodeId, Value};
use std::collections::HashMap;
use thiserror::Error;
#[derive(Error, Debug, Clone, PartialEq)]
pub enum RuntimeError {
#[error("Node with ID {0} not found in graph {1}. Verify that the graph was built correctly and all nodes exist.")]
NodeNotFound(NodeId, AsgId),
#[error("Graph with ID {0} not found in execution context. Ensure the graph was registered before execution.")]
GraphNotFound(AsgId),
#[error(
"Type mismatch: operation expected {expected}, but got {actual}. Check input data types."
)]
TypeError { expected: String, actual: String },
#[error("Tensor shape error: {0}. Check input tensor dimensions.")]
ShapeError(String),
#[error("Missing value for input '{0}' (node ID: {1}). Add this value to initial_data when calling backend.run().")]
MissingInput(String, NodeId),
#[error("Missing value for parameter '{0}' (node ID: {1}). Initialize the parameter before executing the graph.")]
MissingParameter(String, NodeId),
#[error("Operation '{0}' is not implemented in current backend. Consider using an alternative operation or implement support.")]
UnimplementedOperation(String),
#[error("Computation error: {0}")]
ComputationError(String),
#[error("Memory error: {0}")]
MemoryError(String),
}
/// Memoization table mapping a `(graph ID, node ID)` pair to a value of type `T`.
///
/// [`Backend::run`] uses it with `T = Backend::DeviceData`, both as the
/// starting table and as part of its result.
pub type Memo<T> = std::collections::HashMap<(AsgId, NodeId), T>;
/// An execution engine capable of evaluating an [`Asg`].
///
/// The signatures suggest a three-step workflow:
/// 1. [`Backend::load_data`] turns host [`Value`]s into backend data.
/// 2. [`Backend::run`] evaluates a graph.
/// 3. [`Backend::retrieve_data`] converts results back into host [`Value`]s.
///
/// Every step reports failures as [`RuntimeError`].
pub trait Backend {
    /// Backend-specific representation of a value.
    ///
    /// Presumably a device buffer or similar; confirm against the implementations.
    /// It must implement `Debug` so it can be printed in diagnostics.
    type DeviceData: std::fmt::Debug;

    /// Converts a map of named host values into backend data.
    ///
    /// The result is also a map keyed by `String`. It presumably uses the same
    /// names as `data`; confirm in the implementations.
    ///
    /// # Errors
    /// Returns a [`RuntimeError`] when the conversion fails.
    fn load_data(
        &self,
        data: &HashMap<String, Value>,
    ) -> Result<HashMap<String, Self::DeviceData>, RuntimeError>;

    /// Evaluates `main_asg`, seeded with the values already in `initial_memo`.
    ///
    /// On success it returns a pair:
    /// - a vector of backend values (presumably the graph's outputs; confirm), and
    /// - the memo table as it stands after execution.
    ///
    /// # Errors
    /// Returns a [`RuntimeError`] when evaluation fails.
    fn run(
        &self,
        main_asg: &Asg,
        initial_memo: Memo<Self::DeviceData>,
    ) -> Result<(Vec<Self::DeviceData>, Memo<Self::DeviceData>), RuntimeError>;

    /// Converts a slice of backend values back into host [`Value`]s.
    ///
    /// # Errors
    /// Returns a [`RuntimeError`] when any value cannot be converted.
    fn retrieve_data(
        &self,
        device_data: &[Self::DeviceData],
    ) -> Result<Vec<Value>, RuntimeError>;
}