use crate::eval::Value;
use crate::diagnostics::{Error, Result};
use super::ConcurrencyError;
use std::sync::{Arc, Mutex};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use serde::{Serialize, Deserialize};
use uuid::Uuid;
/// Identifier of a node taking part in RPC communication.
///
/// Newtype over a [`Uuid`]. It is `Copy`, hashable (it is used as the key of
/// the connection maps in this file) and serializable so that it can be
/// embedded in an [`RpcRequest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NodeId(Uuid);
impl NodeId {
pub fn new() -> Self {
Self(Uuid::new_v4())
}
pub fn from_uuid(uuid: Uuid) -> Self {
Self(uuid)
}
pub fn as_uuid(&self) -> Uuid {
self.0
}
}
impl std::fmt::Display for NodeId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "node-{}", self.0)
}
}
impl Default for NodeId {
fn default() -> Self {
Self::new()
}
}
/// A single remote procedure call as it is sent over the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcRequest {
/// Identifier used to match the response to this request;
/// `RpcClient::call` fills it with a freshly generated UUID string.
pub id: String,
/// Name of the target service; the server looks it up in its service map.
pub service: String,
/// Method name, interpreted by the service that handles the request.
pub method: String,
/// Call arguments in their serializable form.
pub args: Vec<SerializableValue>,
/// Node that issued the request.
pub sender: NodeId,
/// Creation time in milliseconds since the UNIX epoch.
pub timestamp: u64,
/// Optional caller-side timeout in milliseconds.
/// NOTE(review): the server code in this file never reads this field.
pub timeout: Option<u64>,
}
/// Reply to an [`RpcRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcResponse {
/// `id` of the request this response answers.
pub request_id: String,
/// Either the returned value or a human-readable error message.
pub result: std::result::Result<SerializableValue, String>,
/// Creation time of the response in milliseconds since the UNIX epoch.
pub timestamp: u64,
/// Handling time in microseconds, as filled in by the responders in this file.
pub processing_time: Option<u64>,
}
/// Serde-friendly representation of an interpreter `Value` used as RPC payload.
///
/// Conversions to and from `Value` live in `from_value` / `to_value`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SerializableValue {
/// The empty list / nil value.
Nil,
/// A boolean.
Boolean(bool),
/// A 64-bit signed integer.
Integer(i64),
/// A 64-bit floating point number.
Float(f64),
/// A string (also used for characters and a few placeholder encodings).
String(String),
/// A symbol, carried as text.
Symbol(String),
/// A list of values.
List(Vec<SerializableValue>),
/// A vector of values.
Vector(Vec<SerializableValue>),
/// A string-keyed map.
/// NOTE(review): `to_value` currently turns this into `Value::Nil`.
Map(HashMap<String, SerializableValue>),
/// Raw bytes.
/// NOTE(review): `to_value` currently decodes this lossily as UTF-8 text.
Bytes(Vec<u8>),
}
impl SerializableValue {
pub fn from_value(value: &Value) -> Result<Self> {
match value {
Value::Nil => Ok(SerializableValue::Nil),
Value::Literal(lit) => match lit {
crate::ast::Literal::Boolean(b) => Ok(SerializableValue::Boolean(*b)),
crate::ast::Literal::ExactInteger(i) => Ok(SerializableValue::Integer(*i)),
crate::ast::Literal::InexactReal(f) => Ok(SerializableValue::Float(*f)),
crate::ast::Literal::Number(f) => Ok(SerializableValue::Float(*f)),
crate::ast::Literal::Rational { numerator, denominator } => {
if *denominator == 1 {
Ok(SerializableValue::Integer(*numerator))
} else {
Ok(SerializableValue::Float(*numerator as f64 / *denominator as f64))
}
}
crate::ast::Literal::Complex { real, imaginary: _ } =>
Ok(SerializableValue::Float(*real)), crate::ast::Literal::String(s) => Ok(SerializableValue::String(s.clone())),
crate::ast::Literal::Character(c) => Ok(SerializableValue::String(c.to_string())),
crate::ast::Literal::Bytevector(bytes) =>
Ok(SerializableValue::String(format!("bytevector-{}", bytes.len()))),
crate::ast::Literal::Nil => Ok(SerializableValue::Nil),
crate::ast::Literal::Unspecified => Ok(SerializableValue::String("unspecified".to_string())),
},
Value::Symbol(sym) => Ok(SerializableValue::Symbol(format!("symbol-{}", sym.0))),
Value::Pair(_car, _cdr) => {
let mut list = Vec::new();
let mut current = value;
loop {
match current {
Value::Pair(car, cdr) => {
list.push(Self::from_value(car)?);
current = cdr;
}
Value::Nil => break,
_ => {
list.push(Self::from_value(current)?);
break;
}
}
}
Ok(SerializableValue::List(list))
}
Value::Vector(vec) => {
let guard = vec.read().unwrap();
let mut serializable_vec = Vec::new();
for item in guard.iter() {
serializable_vec.push(Self::from_value(item)?);
}
Ok(SerializableValue::Vector(serializable_vec))
}
_ => Err(Box::new(Error::runtime_error(
format!("Cannot serialize value type: {value:?}"),
None,
))),
}
}
pub fn to_value(&self) -> Result<Value> {
match self {
SerializableValue::Nil => Ok(Value::Nil),
SerializableValue::Boolean(b) => Ok(Value::Literal(crate::ast::Literal::Boolean(*b))),
SerializableValue::Integer(i) => Ok(Value::Literal(crate::ast::Literal::integer(*i))),
SerializableValue::Float(f) => Ok(Value::Literal(crate::ast::Literal::float(*f))),
SerializableValue::String(s) => Ok(Value::Literal(crate::ast::Literal::String(s.clone()))),
SerializableValue::Symbol(s) => {
Ok(Value::Symbol(crate::utils::SymbolId(s.len())))
}
SerializableValue::List(list) => {
let mut result = Value::Nil;
for item in list.iter().rev() {
let value = item.to_value()?;
result = Value::pair(value, result);
}
Ok(result)
}
SerializableValue::Vector(vec) => {
let mut values = Vec::new();
for item in vec {
values.push(item.to_value()?);
}
Ok(Value::Vector(Arc::new(std::sync::RwLock::new(values))))
}
SerializableValue::Map(_map) => {
Ok(Value::Nil) }
SerializableValue::Bytes(bytes) => {
Ok(Value::Literal(crate::ast::Literal::String(
String::from_utf8_lossy(bytes).to_string()
)))
}
}
}
}
/// A named handler that can be registered with an `RpcServer`.
#[async_trait::async_trait]
pub trait RpcService: Send + Sync + std::fmt::Debug {
/// Handles one request and produces the response that is sent back.
async fn handle_request(&self, request: RpcRequest) -> RpcResponse;
/// Name under which the service is registered and looked up
/// (matched against `RpcRequest::service`).
fn service_name(&self) -> &str;
}
/// Client side of the RPC layer: keeps one connection per remote node.
///
/// Cloning is cheap; clones share the same connection table.
#[derive(Debug, Clone)]
pub struct RpcClient {
// Identity of the local node; stamped into every request as `sender`.
node_id: NodeId,
// Open connections keyed by the remote node's id.
connections: Arc<Mutex<HashMap<NodeId, Arc<Connection>>>>,
}
impl RpcClient {
pub fn new(node_id: NodeId) -> Self {
Self {
node_id,
connections: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn connect(&self, node_id: NodeId, addr: SocketAddr) -> Result<()> {
let stream = TcpStream::connect(addr).await
.map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
let connection = Arc::new(Connection::new(stream));
{
let mut connections = self.connections.lock().unwrap();
connections.insert(node_id, connection);
}
Ok(())
}
pub async fn call(
&self,
target_node: NodeId,
service: String,
method: String,
args: Vec<Value>,
timeout: Option<Duration>,
) -> Result<Value> {
let connection = {
let connections = self.connections.lock().unwrap();
connections.get(&target_node)
.ok_or_else(|| ConcurrencyError::Network("Node not connected".to_string()).boxed())?
.clone()
};
let request_id = Uuid::new_v4().to_string();
let serializable_args: Result<Vec<_>> = args.iter()
.map(SerializableValue::from_value)
.collect();
let request = RpcRequest {
id: request_id.clone(),
service,
method,
args: serializable_args?,
sender: self.node_id,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
timeout: timeout.map(|t| t.as_millis() as u64),
};
let response = connection.send_request(request, timeout.unwrap_or(Duration::from_secs(30))).await?;
match response.result {
Ok(value) => value.to_value(),
Err(error) => Err(Error::runtime_error(error, None).boxed()),
}
}
}
/// Server side of the RPC layer: accepts connections and dispatches requests
/// to registered services.
#[derive(Debug)]
pub struct RpcServer {
// Identity of the local node.
// NOTE(review): not read by the server code visible in this file.
node_id: NodeId,
// Listening socket; `None` until `bind` succeeds.
listener: Option<TcpListener>,
// Registered services keyed by `RpcService::service_name`.
services: Arc<Mutex<HashMap<String, Arc<dyn RpcService>>>>,
// Per-node connection table.
// NOTE(review): never populated by the code visible in this file.
connections: Arc<Mutex<HashMap<NodeId, Arc<Connection>>>>,
}
impl RpcServer {
    /// Creates an unbound server for the local node `node_id` with no services.
    pub fn new(node_id: NodeId) -> Self {
        Self {
            node_id,
            listener: None,
            services: Arc::new(Mutex::new(HashMap::new())),
            connections: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Binds the listening socket to `addr`.
    ///
    /// # Errors
    /// Returns a network error when the address cannot be bound.
    pub async fn bind(&mut self, addr: SocketAddr) -> Result<()> {
        let listener = TcpListener::bind(addr)
            .await
            .map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
        self.listener = Some(listener);
        Ok(())
    }

    /// Registers `service` under its own `service_name`, replacing any
    /// service previously registered under the same name.
    pub fn register_service(&self, service: Arc<dyn RpcService>) {
        let mut services = self.services.lock().unwrap();
        services.insert(service.service_name().to_string(), service);
    }

    /// Accepts connections forever, handling each one on its own task.
    ///
    /// # Errors
    /// Returns a runtime error if the server was never bound, and a network
    /// error (ending the accept loop) if accepting a connection fails.
    pub async fn serve(&self) -> Result<()> {
        let listener = self
            .listener
            .as_ref()
            .ok_or_else(|| Error::runtime_error("Server not bound to address".to_string(), None))?;
        loop {
            let (stream, _addr) = listener
                .accept()
                .await
                .map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
            let connection = Arc::new(Connection::new(stream));
            let services = self.services.clone();
            tokio::spawn(async move {
                if let Err(e) = Self::handle_connection(connection, services).await {
                    eprintln!("Connection error: {e}");
                }
            });
        }
    }

    /// Reads requests from `connection` until a read fails, processing each
    /// request on its own task so slow handlers do not block later requests.
    async fn handle_connection(
        connection: Arc<Connection>,
        services: Arc<Mutex<HashMap<String, Arc<dyn RpcService>>>>,
    ) -> Result<()> {
        loop {
            match connection.receive_request().await {
                Ok(request) => {
                    let services = services.clone();
                    let connection = connection.clone();
                    tokio::spawn(async move {
                        let response = Self::process_request(request, services).await;
                        if let Err(e) = connection.send_response(response).await {
                            eprintln!("Failed to send response: {e}");
                        }
                    });
                }
                Err(e) => {
                    // A read failure (including the peer closing the socket)
                    // ends this connection's loop.
                    eprintln!("Failed to receive request: {e}");
                    break;
                }
            }
        }
        Ok(())
    }

    /// Dispatches `request` to the service named in it, or builds a
    /// "service not found" error response.
    async fn process_request(
        request: RpcRequest,
        services: Arc<Mutex<HashMap<String, Arc<dyn RpcService>>>>,
    ) -> RpcResponse {
        let start_time = Instant::now();
        // Clone the handler out of the map inside a block so the std mutex
        // guard is released before awaiting the handler.
        let service = {
            let services = services.lock().unwrap();
            services.get(&request.service).cloned()
        };
        match service {
            // The request is moved into the handler; no copy of the arguments is made.
            Some(service) => service.handle_request(request).await,
            None => RpcResponse {
                request_id: request.id,
                result: Err(format!("Service '{service}' not found", service = request.service)),
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_millis() as u64,
                processing_time: Some(start_time.elapsed().as_micros() as u64),
            },
        }
    }
}
struct Connection {
stream: Arc<tokio::sync::Mutex<TcpStream>>,
pending_requests: Arc<tokio::sync::Mutex<HashMap<String, tokio::sync::oneshot::Sender<RpcResponse>>>>,
}
impl std::fmt::Debug for Connection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Connection")
.field("stream", &"<TcpStream>")
.field("pending_requests", &"<PendingRequests>")
.finish()
}
}
impl Connection {
fn new(stream: TcpStream) -> Self {
Self {
stream: Arc::new(tokio::sync::Mutex::new(stream)),
pending_requests: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
}
}
async fn send_request(&self, request: RpcRequest, timeout: Duration) -> Result<RpcResponse> {
let (tx, rx) = tokio::sync::oneshot::channel();
{
let mut pending = self.pending_requests.lock().await;
pending.insert(request.id.clone(), tx);
}
let data = serde_json::to_vec(&request)
.map_err(|e| ConcurrencyError::Serialization(e.to_string()).boxed())?;
let mut stream = self.stream.lock().await;
stream.write_u32(data.len() as u32).await
.map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
stream.write_all(&data).await
.map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
tokio::time::timeout(timeout, rx)
.await
.map_err(|_| ConcurrencyError::Timeout.boxed())?
.map_err(|_| Error::runtime_error("Request cancelled".to_string(), None).boxed())
}
async fn receive_request(&self) -> Result<RpcRequest> {
let mut stream = self.stream.lock().await;
let len = stream.read_u32().await
.map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
let mut buffer = vec![0u8; len as usize];
stream.read_exact(&mut buffer).await
.map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
serde_json::from_slice(&buffer)
.map_err(|e| ConcurrencyError::Serialization(e.to_string()).boxed())
}
async fn send_response(&self, response: RpcResponse) -> Result<()> {
let data = serde_json::to_vec(&response)
.map_err(|e| ConcurrencyError::Serialization(e.to_string()).boxed())?;
let mut stream = self.stream.lock().await;
stream.write_u32(data.len() as u32).await
.map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
stream.write_all(&data).await
.map_err(|e| ConcurrencyError::Network(e.to_string()).boxed())?;
Ok(())
}
}
/// Example arithmetic service, registered under the name `"calculator"`,
/// that answers the `add` and `multiply` methods.
#[derive(Debug)]
pub struct CalculatorService;
#[async_trait::async_trait]
impl RpcService for CalculatorService {
async fn handle_request(&self, request: RpcRequest) -> RpcResponse {
let start_time = Instant::now();
let result = match request.method.as_str() {
"add" => {
if request.args.len() != 2 {
Err("add requires exactly 2 arguments".to_string())
} else {
match (&request.args[0], &request.args[1]) {
(SerializableValue::Integer(a), SerializableValue::Integer(b)) => {
Ok(SerializableValue::Integer(a + b))
}
(SerializableValue::Float(a), SerializableValue::Float(b)) => {
Ok(SerializableValue::Float(a + b))
}
_ => Err("add requires numeric arguments".to_string()),
}
}
}
"multiply" => {
if request.args.len() != 2 {
Err("multiply requires exactly 2 arguments".to_string())
} else {
match (&request.args[0], &request.args[1]) {
(SerializableValue::Integer(a), SerializableValue::Integer(b)) => {
Ok(SerializableValue::Integer(a * b))
}
(SerializableValue::Float(a), SerializableValue::Float(b)) => {
Ok(SerializableValue::Float(a * b))
}
_ => Err("multiply requires numeric arguments".to_string()),
}
}
}
_ => Err(format!("Unknown method: {}", request.method)),
};
RpcResponse {
request_id: request.id,
result,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as u64,
processing_time: Some(start_time.elapsed().as_micros() as u64),
}
}
fn service_name(&self) -> &str {
"calculator"
}
}
/// A node that bundles an RPC client and an RPC server under one identity.
#[derive(Debug)]
pub struct DistributedNode {
// Identity shared by the client and the server below.
id: NodeId,
// Outgoing side: used to call other nodes.
rpc_client: RpcClient,
// Incoming side: serves requests from other nodes.
rpc_server: RpcServer,
}
impl DistributedNode {
pub fn new() -> Self {
let id = NodeId::new();
Self {
id,
rpc_client: RpcClient::new(id),
rpc_server: RpcServer::new(id),
}
}
pub fn id(&self) -> NodeId {
self.id
}
pub fn rpc_client(&self) -> &RpcClient {
&self.rpc_client
}
pub fn rpc_server(&mut self) -> &mut RpcServer {
&mut self.rpc_server
}
pub async fn start(&mut self, addr: SocketAddr) -> Result<()> {
self.rpc_server.bind(addr).await?;
println!("RPC server bound to {addr}");
Ok(())
}
}
impl Default for DistributedNode {
fn default() -> Self {
Self::new()
}
}
/// Namespace for collection operations that spread work over several nodes.
pub struct DistributedOps;
impl DistributedOps {
pub async fn distributed_map<F>(
nodes: Vec<NodeId>,
client: &RpcClient,
data: Vec<Value>,
_map_fn: F,
) -> Result<Vec<Value>>
where
F: Fn(&Value) -> Result<Value> + Send + Sync + 'static,
{
if nodes.is_empty() {
return Err(Box::new(Error::runtime_error("No nodes available".to_string(), None)))
}
let chunk_size = data.len().div_ceil(nodes.len());
let mut futures = Vec::new();
for (i, chunk) in data.chunks(chunk_size).enumerate() {
let node_id = nodes[i % nodes.len()];
let chunk_data = chunk.to_vec();
let future = client.call(
node_id,
"distributed".to_string(),
"map".to_string(),
chunk_data,
Some(Duration::from_secs(30)),
);
futures.push(future);
}
let mut results = Vec::new();
for future in futures {
let result = future.await?;
if let Value::Pair(_, _) = result {
let mut current = &result;
loop {
match current {
Value::Pair(car, cdr) => {
results.push((**car).clone());
current = cdr;
}
Value::Nil => break,
_ => {
results.push(current.clone());
break;
}
}
}
}
}
Ok(results)
}
pub async fn distributed_reduce(
nodes: Vec<NodeId>,
client: &RpcClient,
data: Vec<Value>,
identity: Value,
) -> Result<Value> {
let partial_results = Self::distributed_map(
nodes.clone(),
client,
data,
|_| Ok(Value::Nil), ).await?;
let mut result = identity;
for partial in partial_results {
result = partial; }
Ok(result)
}
}