mod activity;
mod mcp;
mod self_update;
mod tools;
use anyhow::{Context, Result};
use chrono::Utc;
use hyphae::Watchable;
use marshal_entities::{GetAllSessions, HostInfo, NotifyChannel, Session, SessionId};
use mcp::ServerConfig;
use myko::{
client::{ConnectionStatus, MykoClient},
wire::{MEvent, MEventType},
};
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;
use uuid::Uuid;
/// Daemon address used when neither address environment variable is set.
const DEFAULT_DAEMON_ADDRESS: &str = "ws://localhost:6155";
/// Name of the environment variable consulted first for the daemon address.
const ADDRESS_ENV: &str = "MARSHAL_DAEMON_ADDRESS";
/// Name of the environment variable consulted when `ADDRESS_ENV` cannot be read.
const ADDRESS_ENV_LEGACY: &str = "MYKO_ADDRESS";
/// Entry point of the shim process.
///
/// In order, this function:
/// 1. handles the `--check` health-check mode,
/// 2. initialises logging and resolves the daemon address,
/// 3. builds the local `Session` record describing this process,
/// 4. subscribes to connection-status changes so the session is (re)sent on
///    every `Connected` status,
/// 5. spawns a 5-second publisher that pushes activity/tool fields to the
///    daemon when they change,
/// 6. runs `mcp::serve_stdio` and returns its result.
///
/// # Errors
/// Returns an error if the current working directory cannot be read, or
/// whatever error `mcp::serve_stdio` resolves to.
///
/// # Panics
/// The connection-status callback calls `lock().unwrap()` on the session
/// mutex and therefore panics if that mutex has been poisoned.
#[tokio::main]
async fn main() -> Result<()> {
// Health-check mode: when `--check` is the one and only argument, print "ok"
// and exit before logging or any connection is set up. Any other argument
// list falls through to normal startup.
let mut argv = std::env::args().skip(1);
if let Some(arg) = argv.next() {
if arg == "--check" && argv.next().is_none() {
println!("ok");
return Ok(());
}
}
init_logging();
marshal_entities::link();
// Address precedence: ADDRESS_ENV, then ADDRESS_ENV_LEGACY, then the
// compiled-in default.
let daemon_address = std::env::var(ADDRESS_ENV)
.or_else(|_| std::env::var(ADDRESS_ENV_LEGACY))
.unwrap_or_else(|_| DEFAULT_DAEMON_ADDRESS.to_string());
log::info!("[marshal-shim] connecting to {daemon_address}");
let client = Arc::new(MykoClient::new());
// Incoming `NotifyChannel` commands are forwarded into an unbounded channel;
// the receiving half is drained by the task started inside the
// `serve_stdio` callback at the bottom of this function.
let (notify_tx, notify_rx) = mpsc::unbounded_channel::<NotifyChannel>();
let notify_tx_oncmd = notify_tx.clone();
let notify_guard = client.on_command::<NotifyChannel, _>(move |cmd, _responder| {
// A send error (receiver gone) is deliberately ignored.
let _ = notify_tx_oncmd.send(cmd);
});
// The guard is leaked so it is never dropped for the life of the process.
// NOTE(review): presumably dropping it would unregister the handler — confirm.
Box::leak(Box::new(notify_guard));
let cwd = std::env::current_dir()
.context("getting cwd")?
.display()
.to_string();
let pid = std::process::id();
// Nickname is the last path component of the cwd, or "session" when there
// is none, it is not valid UTF-8, or it is empty.
let nickname = std::path::Path::new(&cwd)
.file_name()
.and_then(|s| s.to_str())
.filter(|s| !s.is_empty())
.unwrap_or("session")
.to_string();
let git_branch = detect_git_branch(&cwd);
let project = detect_project_basename(&cwd);
let operator = detect_operator();
let host = detect_host();
// A fresh random (UUID v4) session id is generated on every start.
let session_id = SessionId(Arc::from(Uuid::new_v4().to_string()));
let session = Session {
id: session_id.clone(),
client_id: None,
nickname: nickname.clone(),
pid,
cwd: cwd.clone(),
git_branch: git_branch.clone(),
current_task: None,
connected_at: Utc::now().timestamp_millis(),
last_activity_at: None,
last_tool: None,
last_tool_at: None,
operator: Some(operator.clone()),
host: Some(host.clone()),
project: project.clone(),
};
// Shared, mutable copy of the session: the publisher task below writes the
// activity fields into it and the connect callback snapshots it.
let session = Arc::new(Mutex::new(session));
// Query watches handed to `tools::ToolHost` further down.
let sessions_cell = client.watch_query::<GetAllSessions>(GetAllSessions {});
let rooms_cell =
client.watch_query::<marshal_entities::GetAllRooms>(marshal_entities::GetAllRooms {});
let members_cell = client
.watch_query::<marshal_entities::GetAllRoomMembers>(marshal_entities::GetAllRoomMembers {});
let session_for_resend = Arc::clone(&session);
let client_for_resend = Arc::clone(&client);
// On every `Connected` status, send the full current session as a SET event.
// This subscription is registered before `set_address` is called below.
// NOTE(review): presumably so the very first connection is observed — confirm
// that `set_address` is what initiates the connection.
let conn_guard = client.connection_status().subscribe(move |signal| {
if let hyphae::Signal::Value(status) = signal {
match &**status {
ConnectionStatus::Connected(addr) => {
log::info!("[marshal-shim] connected to {addr} — (re)sending session");
// Panics if the session mutex is poisoned.
let snapshot = session_for_resend.lock().unwrap().clone();
if let Err(e) = emit_session_set(&client_for_resend, &snapshot) {
log::warn!("[marshal-shim] re-SET on connect failed: {e}");
}
}
ConnectionStatus::Disconnected => {
log::warn!("[marshal-shim] disconnected");
}
// Every other status is ignored.
_ => {}
}
}
});
// NOTE(review): `own` presumably ties the guard's lifetime to the status
// cell so the subscription stays alive — confirm against hyphae.
client.connection_status().own(conn_guard);
client.set_address(Some(daemon_address));
// `host` is rebound here: from this point it names the tool host, not the
// `HostInfo` detected earlier (which was already cloned into the session).
let host = Arc::new(tools::ToolHost {
client: Arc::clone(&client),
session_id: session_id.clone(),
nickname: nickname.clone(),
pid,
cwd: cwd.clone(),
session: Arc::clone(&session),
sessions_cell,
rooms_cell,
members_cell,
});
let handler = Arc::new(tools::CoordHandler { host });
// Server configuration; the instructions text interpolates the nickname,
// session id and cwd computed above.
let config = ServerConfig {
name: "marshal-shim".into(),
version: env!("CARGO_PKG_VERSION").into(),
instructions: format!(
"You are session '{nickname}' (id {}) in {cwd}. Coordinate with \
sibling Claude sessions via the marshal daemon.\n\
\n\
READ paths are resources (use `resources/read`):\n\
- marshal://whoami — your session id, nickname, pid, cwd, operator, host\n\
- marshal://roster — every live session and what room(s) it's in\n\
- marshal://rooms — every room and who its members are\n\
- marshal://messages — message history; supports query params:\n\
inbox=true, sent=true, unread=true,\n\
room=ID, from=SID, to_session=SID,\n\
since=MILLIS, limit=N\n\
\n\
WRITE paths are tools (use `tools/call`):\n\
- send_message — direct send to a peer's session_id\n\
- broadcast — fan-out to all members of a room\n\
- join_room — create or join an ad-hoc room\n\
- leave_room — leave an ad-hoc room\n\
- set_status — set this session's free-form status text\n\
- ack_messages — mark message ids as read for this session\n\
\n\
Inbound peer messages arrive as `notifications/claude/channel` \
events; reply with `send_message` or `broadcast`.",
session_id.0
),
tools: tools::tools_def(),
resources: tools::resources_def(),
};
// One shared activity tracker is given to the self-update module, the
// publisher task below, and `serve_stdio`.
let activity = Arc::new(activity::Activity::new());
self_update::spawn(Arc::clone(&activity));
let activity_for_publish = Arc::clone(&activity);
let client_for_publish = Arc::clone(&client);
let session_for_publish = Arc::clone(&session);
let session_id_for_publish = session_id.clone();
// Publisher task: every 5 seconds, read the activity tracker and send a
// command for each field whose value differs from the last one pushed.
tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(5));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
// Last values pushed to the daemon; `None` initially.
let mut pushed_activity_at: Option<i64> = None;
let mut pushed_tool: Option<String> = None;
let mut pushed_tool_at: Option<i64> = None;
loop {
interval.tick().await;
// A timestamp that is not strictly positive is reported as `None`.
let last_activity_at = activity_for_publish.last_activity_ms();
let last_activity_at = if last_activity_at > 0 {
Some(last_activity_at)
} else {
None
};
let last_tool = activity_for_publish.last_tool_name();
let last_tool_at = activity_for_publish.last_tool_ms();
let last_tool_at = if last_tool_at > 0 {
Some(last_tool_at)
} else {
None
};
// NOTE(review): in each of the three sends below, whatever
// `send_command` returns is discarded and the `pushed_*` marker is
// updated unconditionally, so a send that did not go through is not
// repeated until the value changes again.
if pushed_activity_at != last_activity_at {
let _ = client_for_publish
.send_command::<marshal_entities::SetSessionLastActivityAt, ()>(
&marshal_entities::SetSessionLastActivityAt {
id: session_id_for_publish.clone(),
last_activity_at,
},
);
pushed_activity_at = last_activity_at;
}
if pushed_tool != last_tool {
let arc_tool = last_tool.as_deref().map(Arc::<str>::from);
let _ = client_for_publish
.send_command::<marshal_entities::SetSessionLastTool, ()>(
&marshal_entities::SetSessionLastTool {
id: session_id_for_publish.clone(),
last_tool: arc_tool,
},
);
pushed_tool = last_tool.clone();
}
if pushed_tool_at != last_tool_at {
let _ = client_for_publish
.send_command::<marshal_entities::SetSessionLastToolAt, ()>(
&marshal_entities::SetSessionLastToolAt {
id: session_id_for_publish.clone(),
last_tool_at,
},
);
pushed_tool_at = last_tool_at;
}
// Mirror the values into the shared session so the snapshot sent on
// (re)connect carries them. Skipped silently if the mutex is poisoned.
if let Ok(mut sess) = session_for_publish.lock() {
sess.last_activity_at = last_activity_at;
sess.last_tool = last_tool.clone();
sess.last_tool_at = last_tool_at;
}
}
});
// The receiver sits in a `Mutex<Option<_>>` so the callback can take it
// exactly once: the first invocation starts the drain task, and any later
// invocation (or a poisoned lock) finds nothing and does nothing.
let notify_rx = Mutex::new(Some(notify_rx));
mcp::serve_stdio(config, handler, Arc::clone(&activity), move |notifier| {
if let Some(mut rx) = notify_rx.lock().ok().and_then(|mut g| g.take()) {
// Drain task: forward each queued command's content and meta to the
// notifier until the channel closes.
tokio::spawn(async move {
while let Some(cmd) = rx.recv().await {
notifier.channel(cmd.content, cmd.meta);
}
});
log::info!("[marshal-shim] notification drain task started");
}
})
.await
}
/// Sends `session` to the daemon as a `SET` event.
///
/// Each call tags the event with a freshly generated UUID v4 string.
///
/// # Errors
/// Returns an error whose message starts with `send_event failed:` when the
/// client's `send_event` call reports a failure.
fn emit_session_set(client: &MykoClient, session: &Session) -> Result<()> {
    let event_tag = Uuid::new_v4().to_string();
    let set_event = MEvent::from_item(session, MEventType::SET, &event_tag);
    match client.send_event(set_event) {
        Ok(_) => Ok(()),
        Err(e) => Err(anyhow::anyhow!("send_event failed: {e}")),
    }
}
/// Installs the process-wide logger, writing to stderr.
///
/// The builder is configured from the default environment; when `RUST_LOG`
/// cannot be read, the level filter is set to `Info`.
fn init_logging() {
    let mut builder = env_logger::Builder::from_default_env();
    let rust_log_missing = std::env::var("RUST_LOG").is_err();
    if rust_log_missing {
        builder.filter_level(log::LevelFilter::Info);
    }
    // Logs are directed to stderr rather than stdout.
    builder.target(env_logger::Target::Stderr);
    builder.init();
}
/// Returns the name of the branch checked out in `cwd`, if any.
///
/// Runs `git rev-parse --abbrev-ref HEAD` with `cwd` as the working
/// directory. Yields `None` when the command cannot be run, exits
/// unsuccessfully, prints non-UTF-8 output, or prints nothing or the literal
/// `HEAD` after trimming.
fn detect_git_branch(cwd: &str) -> Option<String> {
    let mut git = std::process::Command::new("git");
    git.args(["rev-parse", "--abbrev-ref", "HEAD"])
        .current_dir(cwd);
    let output = match git.output() {
        Ok(finished) if finished.status.success() => finished,
        _ => return None,
    };
    let raw = String::from_utf8(output.stdout).ok()?;
    match raw.trim() {
        // An empty result or the bare word "HEAD" is not reported as a branch.
        "" | "HEAD" => None,
        branch => Some(branch.to_string()),
    }
}
/// Returns the final path component of the git top-level directory
/// containing `cwd`, if any.
///
/// Runs `git rev-parse --show-toplevel` with `cwd` as the working directory.
/// Yields `None` when the command cannot be run, exits unsuccessfully, prints
/// non-UTF-8 output, or the trimmed path has no non-empty UTF-8 file name.
fn detect_project_basename(cwd: &str) -> Option<String> {
    let run = std::process::Command::new("git")
        .args(["rev-parse", "--show-toplevel"])
        .current_dir(cwd)
        .output();
    let output = run.ok().filter(|finished| finished.status.success())?;
    let stdout = String::from_utf8(output.stdout).ok()?;
    let root = std::path::Path::new(stdout.trim());
    match root.file_name().and_then(|name| name.to_str()) {
        Some(name) if !name.is_empty() => Some(name.to_owned()),
        _ => None,
    }
}
/// Returns the operator name for this session.
///
/// The environment variables `MARSHAL_OPERATOR`, `USER` and `USERNAME` are
/// consulted in that order, and the first one holding a usable value wins.
/// A variable that is unset, not valid Unicode, or blank (empty or only
/// whitespace) is skipped, so an exported-but-empty variable no longer
/// produces an empty operator label. The chosen value is returned unmodified.
/// Falls back to `"anonymous"` when no variable qualifies.
fn detect_operator() -> String {
    ["MARSHAL_OPERATOR", "USER", "USERNAME"]
        .iter()
        // The chain is lazy: later variables are only read if earlier ones
        // were rejected.
        .filter_map(|key| std::env::var(key).ok())
        .find(|value| !value.trim().is_empty())
        .unwrap_or_else(|| "anonymous".to_string())
}
/// Builds the `HostInfo` describing the machine this process runs on.
///
/// The name comes from `gethostname`, replaced by `"unknown"` when it cannot
/// be converted to a `String`; `os` and `arch` are the `std::env::consts`
/// values for this build.
fn detect_host() -> HostInfo {
    let name = match gethostname::gethostname().into_string() {
        Ok(hostname) => hostname,
        Err(_) => "unknown".to_string(),
    };
    let os = std::env::consts::OS.to_string();
    let arch = std::env::consts::ARCH.to_string();
    HostInfo { name, os, arch }
}