use deadpool_redis::{Config as PoolConfig, Connection, Pool, Runtime};
use redis::AsyncCommands;
use serde_json::Value;
use crate::error::AppError;
use crate::models::*;
pub const ENTITY_TTL_SECS: usize = 60;
pub struct LuaScripts {
pub register_engine: String,
pub unregister_engine: String,
pub heartbeat: String,
}
impl LuaScripts {
pub fn load() -> Result<Self, AppError> {
let lua_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("lua");
let read = |name: &str| -> Result<String, AppError> {
std::fs::read_to_string(lua_dir.join(name))
.map_err(|e| AppError::Internal(format!("Failed to load Lua script {name}: {e}")))
};
Ok(Self {
register_engine: read("register_engine.lua")?,
unregister_engine: read("unregister_engine.lua")?,
heartbeat: read("heartbeat_engine.lua")?,
})
}
}
#[derive(Clone)]
pub struct RedisRepo {
pool: Pool,
redis_url: String,
scripts: std::sync::Arc<LuaScripts>,
}
impl RedisRepo {
pub fn new(redis_url: &str, scripts: LuaScripts) -> Result<Self, AppError> {
let cfg = PoolConfig::from_url(redis_url);
let pool = cfg
.create_pool(Some(Runtime::Tokio1))
.map_err(|e| AppError::Internal(format!("Failed to create Redis pool: {e}")))?;
Ok(Self {
pool,
redis_url: redis_url.to_string(),
scripts: std::sync::Arc::new(scripts),
})
}
pub fn redis_url(&self) -> &str {
&self.redis_url
}
pub async fn conn(&self) -> Result<Connection, AppError> {
Ok(self.pool.get().await?)
}
pub fn scoped_key(&self, scope: Option<&str>, parts: &[&str]) -> String {
let suffix = parts.join(":");
match scope {
Some(s) if !s.is_empty() => format!("{s}:{suffix}"),
_ => suffix,
}
}
#[allow(clippy::too_many_arguments)]
pub async fn register_agent(
&self,
scope: Option<&str>,
alias: Option<String>,
name_prefix: &str,
device: &str,
ib_port: u32,
link_type: &str,
address: &str,
) -> Result<String, AppError> {
let mut conn = self.conn().await?;
let agent_name = if let Some(alias) = alias {
alias
} else {
let counter_key = self.scoped_key(scope, &["agent_name_counter"]);
let counter: i64 = redis::cmd("INCR")
.arg(&counter_key)
.query_async(&mut *conn)
.await?;
format!("{name_prefix}-{counter:x}")
};
let key = self.scoped_key(scope, &["agent", &agent_name]);
redis::cmd("HSET")
.arg(&key)
.arg("device")
.arg(device)
.arg("ib_port")
.arg(ib_port.to_string())
.arg("link_type")
.arg(link_type)
.arg("addr")
.arg(address)
.query_async::<()>(&mut *conn)
.await?;
redis::cmd("EXPIRE")
.arg(&key)
.arg(ENTITY_TTL_SECS)
.query_async::<()>(&mut *conn)
.await?;
tracing::info!("Registered peer agent: {agent_name} (TTL: {ENTITY_TTL_SECS}s)");
Ok(agent_name)
}
pub async fn list_agents(&self, scope: Option<&str>) -> Result<Vec<PeerAgent>, AppError> {
let mut conn = self.conn().await?;
let pattern = self.scoped_key(scope, &["agent:*"]);
let keys: Vec<String> = redis::cmd("KEYS")
.arg(&pattern)
.query_async(&mut *conn)
.await?;
let agent_prefix = self.scoped_key(scope, &["agent:"]);
let mut agents = Vec::new();
for key in keys {
let data: std::collections::HashMap<String, String> = redis::cmd("HGETALL")
.arg(&key)
.query_async(&mut *conn)
.await?;
if let (Some(dev), Some(ip)) = (data.get("device"), data.get("addr")) {
let name = key.strip_prefix(&agent_prefix).unwrap_or(&key).to_string();
agents.push(PeerAgent {
name,
device: dev.clone(),
ib_port: data
.get("ib_port")
.and_then(|x| x.parse().ok())
.unwrap_or(1),
link_type: data
.get("link_type")
.cloned()
.unwrap_or_else(|| "RoCE".into()),
address: ip.clone(),
});
}
}
Ok(agents)
}
pub async fn cleanup_agent(
&self,
scope: Option<&str>,
agent_name: &str,
) -> Result<(), AppError> {
let mut conn = self.conn().await?;
let mut peers_to_notify = std::collections::HashSet::new();
let spec_key = self.scoped_key(scope, &["spec:topology", agent_name]);
if let Ok(Some(spec_str)) = redis::cmd("GET")
.arg(&spec_key)
.query_async::<Option<String>>(&mut *conn)
.await
{
if let Ok(spec) = serde_json::from_str::<DesiredTopologySpec>(&spec_str) {
for p in spec.target_peers {
peers_to_notify.insert(p);
}
}
}
let spec_pattern = self.scoped_key(scope, &["spec:topology:*"]);
let spec_prefix = self.scoped_key(scope, &["spec:topology:"]);
let all_spec_keys: Vec<String> = redis::cmd("KEYS")
.arg(&spec_pattern)
.query_async(&mut *conn)
.await
.unwrap_or_default();
for key in &all_spec_keys {
if key == &spec_key {
continue;
}
if let Ok(Some(s)) = redis::cmd("GET")
.arg(key)
.query_async::<Option<String>>(&mut *conn)
.await
{
if let Ok(spec) = serde_json::from_str::<DesiredTopologySpec>(&s) {
if spec.target_peers.contains(&agent_name.to_string()) {
if let Some(other) = key.strip_prefix(&spec_prefix) {
peers_to_notify.insert(other.to_string());
}
}
}
}
}
let agent_key = self.scoped_key(scope, &["agent", agent_name]);
redis::cmd("DEL")
.arg(&agent_key)
.query_async::<()>(&mut *conn)
.await?;
redis::cmd("DEL")
.arg(&spec_key)
.query_async::<()>(&mut *conn)
.await
.ok();
for pattern_str in [
self.scoped_key(scope, &[&format!("exchange:{agent_name}:*")]),
self.scoped_key(scope, &[&format!("exchange:*:{agent_name}")]),
] {
let keys: Vec<String> = redis::cmd("KEYS")
.arg(&pattern_str)
.query_async(&mut *conn)
.await
.unwrap_or_default();
for k in &keys {
redis::cmd("DEL")
.arg(k)
.query_async::<()>(&mut *conn)
.await
.ok();
}
}
let inbox_key = self.scoped_key(scope, &["inbox", agent_name]);
redis::cmd("DEL")
.arg(&inbox_key)
.query_async::<()>(&mut *conn)
.await
.ok();
let mr_pattern = self.scoped_key(scope, &[&format!("mr:{agent_name}:*")]);
let mr_keys: Vec<String> = redis::cmd("KEYS")
.arg(&mr_pattern)
.query_async(&mut *conn)
.await
.unwrap_or_default();
if !mr_keys.is_empty() {
redis::cmd("DEL")
.arg(&mr_keys)
.query_async::<()>(&mut *conn)
.await
.ok();
}
for peer in &peers_to_notify {
let cleanup_event = serde_json::json!({
"type": "cleanup",
"peer": agent_name,
});
let peer_inbox = self.scoped_key(scope, &["inbox", peer]);
redis::cmd("LPUSH")
.arg(&peer_inbox)
.arg(cleanup_event.to_string())
.query_async::<()>(&mut *conn)
.await
.ok();
tracing::info!("Notified peer {peer} to clean up connection with {agent_name}");
}
tracing::info!("Cleanup completed for agent: {agent_name}");
Ok(())
}
pub async fn register_engine(
&self,
scope: Option<&str>,
body: &RegisterEngineBody,
) -> Result<i64, AppError> {
let mut conn = self.conn().await?;
let engine_key = self.scoped_key(scope, &["engine", &body.engine_id]);
let revision_key = self.scoped_key(scope, &["nano_meta:engine_revision"]);
let channel = self.scoped_key(scope, &["nano_events:engine_update"]);
let zmq_address = format!("tcp://{}:{}", body.host, body.port);
let payload = serde_json::json!({
"id": body.engine_id,
"role": body.role,
"host": body.host,
"port": body.port,
"zmq_address": zmq_address,
"world_size": body.world_size,
"num_blocks": body.num_blocks,
"peer_addrs": body.peer_addrs,
"model_path": body.model_path, });
let engine_info = serde_json::json!({
"id": body.engine_id,
"role": body.role,
"world_size": body.world_size,
"num_blocks": body.num_blocks,
"host": body.host,
"port": body.port,
"peer_addrs": body.peer_addrs,
"p2p_host": body.p2p_host.as_deref().unwrap_or_default(),
"p2p_port": body.p2p_port.unwrap_or(0),
"max_num_seqs": body.max_num_seqs.unwrap_or(0),
"model_path": body.model_path, });
let rev: i64 = redis::cmd("EVAL")
.arg(&*self.scripts.register_engine)
.arg(3) .arg(&engine_key)
.arg(&revision_key)
.arg(&channel)
.arg(&body.engine_id) .arg(&body.role) .arg(&body.host) .arg(body.port.to_string()) .arg(body.world_size.to_string()) .arg(body.num_blocks.to_string()) .arg(serde_json::to_string(&body.peer_addrs).unwrap_or_default()) .arg(engine_info.to_string()) .arg(payload.to_string()) .arg(ENTITY_TTL_SECS.to_string()) .arg(body.model_path.as_deref().unwrap_or("")) .query_async(&mut *conn)
.await?;
tracing::info!(
"Registered engine: {} (revision: {}, TTL: {}s)",
body.engine_id,
rev,
ENTITY_TTL_SECS
);
Ok(rev)
}
pub async fn unregister_engine(
&self,
scope: Option<&str>,
engine_id: &str,
) -> Result<i64, AppError> {
let mut conn = self.conn().await?;
let engine_key = self.scoped_key(scope, &["engine", engine_id]);
let revision_key = self.scoped_key(scope, &["nano_meta:engine_revision"]);
let channel = self.scoped_key(scope, &["nano_events:engine_update"]);
let rev: i64 = redis::cmd("EVAL")
.arg(&*self.scripts.unregister_engine)
.arg(3)
.arg(&engine_key)
.arg(&revision_key)
.arg(&channel)
.arg(engine_id)
.query_async(&mut *conn)
.await?;
if rev == 0 {
tracing::warn!("Unregister engine: {engine_id} not found (already removed or never registered), treating as success");
return Ok(0);
}
tracing::info!("Unregistered engine: {engine_id} (revision: {rev})");
Ok(rev)
}
pub async fn heartbeat(
&self,
entity_type: &str,
scope: Option<&str>,
entity_id: &str,
) -> Result<bool, AppError> {
let key = self.scoped_key(scope, &[entity_type, entity_id]);
self.refresh_ttl(&key, ENTITY_TTL_SECS).await
}
async fn refresh_ttl(&self, key: &str, ttl_secs: usize) -> Result<bool, AppError> {
let mut conn = self.conn().await?;
let result: i64 = redis::cmd("EVAL")
.arg(&*self.scripts.heartbeat)
.arg(1)
.arg(key)
.arg(ttl_secs.to_string())
.query_async(&mut *conn)
.await?;
Ok(result == 1)
}
pub async fn get_engine_info(
&self,
scope: Option<&str>,
engine_id: &str,
) -> Result<Option<Value>, AppError> {
let mut conn = self.conn().await?;
let key = self.scoped_key(scope, &["engine", engine_id]);
let info_str: Option<String> = redis::cmd("HGET")
.arg(&key)
.arg("info")
.query_async(&mut *conn)
.await?;
match info_str {
Some(s) => Ok(serde_json::from_str(&s).ok()),
None => Ok(None),
}
}
pub async fn list_engines(&self, scope: Option<&str>) -> Result<Vec<Value>, AppError> {
let mut conn = self.conn().await?;
let pattern = self.scoped_key(scope, &["engine:*"]);
let keys: Vec<String> = redis::cmd("KEYS")
.arg(&pattern)
.query_async(&mut *conn)
.await?;
let mut engines = Vec::new();
for key in keys {
let info_str: Option<String> = redis::cmd("HGET")
.arg(&key)
.arg("info")
.query_async(&mut *conn)
.await
.unwrap_or(None);
if let Some(s) = info_str {
if let Ok(v) = serde_json::from_str::<Value>(&s) {
engines.push(v);
}
}
}
Ok(engines)
}
pub async fn save_topology(
&self,
scope: Option<&str>,
agent_id: &str,
spec: &DesiredTopologySpec,
) -> Result<(), AppError> {
let mut conn = self.conn().await?;
let key = self.scoped_key(scope, &["spec:topology", agent_id]);
let spec_json = serde_json::to_string(spec)?;
conn.set::<_, _, ()>(&key, &spec_json).await?;
let stream_key = self.scoped_key(scope, &["stream", agent_id]);
for target_peer in &spec.target_peers {
self.push_stream_connect(&mut conn, &stream_key, target_peer)
.await;
}
if spec.symmetric {
for target_peer in &spec.target_peers {
let peer_key = self.scoped_key(scope, &["spec:topology", target_peer]);
let existing: Option<String> = conn.get(&peer_key).await.ok().flatten();
let mut peer_targets: Vec<String> = existing
.and_then(|s| serde_json::from_str::<DesiredTopologySpec>(&s).ok())
.map(|p| p.target_peers)
.unwrap_or_default();
if !peer_targets.contains(&agent_id.to_string()) {
peer_targets.push(agent_id.to_string());
}
let peer_spec = DesiredTopologySpec {
target_peers: peer_targets,
min_bw: None,
symmetric: false,
scope: spec.scope.clone(),
};
if let Ok(peer_json) = serde_json::to_string(&peer_spec) {
let _: Result<(), _> = conn.set(&peer_key, &peer_json).await;
}
let peer_stream = self.scoped_key(scope, &["stream", target_peer]);
self.push_stream_connect(&mut conn, &peer_stream, agent_id)
.await;
}
tracing::info!(
"Symmetric: merged {agent_id} into {} target peer(s) spec",
spec.target_peers.len()
);
}
tracing::info!(
"Saved desired topology for agent {agent_id}: {} peer(s)",
spec.target_peers.len()
);
Ok(())
}
pub async fn register_mr(
&self,
scope: Option<&str>,
body: &RegisterMrBody,
) -> Result<(), AppError> {
let mut conn = self.conn().await?;
let key = self.scoped_key(scope, &["mr", &body.agent_name, &body.mr_name]);
let mr_info = serde_json::json!({
"addr": body.addr,
"length": body.length,
"rkey": body.rkey,
"lkey": body.lkey,
});
conn.set::<_, _, ()>(&key, mr_info.to_string()).await?;
tracing::info!(
"Registered MR: {} for agent: {}",
body.mr_name,
body.agent_name
);
Ok(())
}
pub async fn get_mr_info(
&self,
scope: Option<&str>,
dst: &str,
mr_name: &str,
) -> Result<Option<MrInfo>, AppError> {
let mut conn = self.conn().await?;
let key = self.scoped_key(scope, &["mr", dst, mr_name]);
let mr_str: Option<String> = conn.get(&key).await?;
Ok(mr_str.and_then(|s| {
let v: Value = serde_json::from_str(&s).ok()?;
Some(MrInfo {
addr: v["addr"].as_u64()?,
length: v["length"].as_u64()? as usize,
rkey: v["rkey"].as_u64()? as u32,
lkey: v["lkey"].as_u64()? as u32,
})
}))
}
async fn push_stream_connect(&self, conn: &mut Connection, stream_key: &str, peer: &str) {
let timestamp = format!(
"{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs_f64()
);
let message = vec![
("type", "connect_peer"),
("peer", peer),
("timestamp", timestamp.as_str()),
];
if let Err(e) = redis::cmd("XADD")
.arg(stream_key)
.arg("MAXLEN")
.arg("~")
.arg("1000")
.arg("*")
.arg(&message)
.query_async::<String>(&mut **conn)
.await
{
tracing::warn!("Failed to push connect_peer to stream {stream_key}: {e}");
}
}
}