use std::io::Read;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use crate::audit::AuditLog;
use crate::error::{Error, Result};
use crate::exec::ExecResult;
use crate::output::{bound, clean};
use crate::policy::Policy;
use crate::redact::redact;
use crate::transport::{self, local::LocalPty, Transport};
/// ASCII Unit Separator (0x1f, written as `\037` in the shell `printf` payload).
/// It delimits the exit-code, working-directory and stderr fields of the
/// result frame that `Session::exec` parses.
const US: u8 = 0x1f;
/// A command-execution session on top of a [`Transport`].
///
/// Commands are written to the transport wrapped in a shell snippet that
/// prints per-session marker strings around the exit status, working
/// directory and captured stderr, so `exec` can split the byte stream back
/// into a structured [`ExecResult`].
pub struct Session {
/// Byte channel that `exec` writes command payloads to and reads output from.
io: Box<dyn Transport>,
/// Per-session token (from `unique_token`) embedded in the start/end markers
/// and in the fallback temp-file name.
token: String,
/// Optional policy; when set, `exec` calls `check` on every command before sending it.
policy: Option<Policy>,
/// Optional audit sink; when set, every successfully parsed result is recorded.
audit: Option<AuditLog>,
/// Maximum time a single `exec` call waits for the end-of-command frame (default 30 s).
timeout: Duration,
/// Limit handed to `bound` for stdout and for stderr (default 100_000); the
/// raw capture buffer is sized from it as well.
max_output: usize,
/// Set by `exec` once it can no longer trust the shell's state (e.g. after a
/// timeout); every later `exec` then fails with `Error::SessionPoisoned`.
poisoned: bool,
}
impl Session {
pub fn local() -> Result<Self> {
let pty = LocalPty::spawn("bash", &["--norc", "--noprofile"])?;
Self::from_transport(Box::new(pty))
}
#[cfg(feature = "ssh")]
pub fn ssh(config: crate::transport::ssh::SshConfig) -> Result<Self> {
let t = crate::transport::ssh::SshTransport::connect(config)?;
Self::from_transport(Box::new(t))
}
fn from_transport(mut io: Box<dyn Transport>) -> Result<Self> {
transport::shell_init(io.as_mut())?;
Ok(Self {
io,
token: unique_token(),
policy: None,
audit: None,
timeout: Duration::from_secs(30),
max_output: 100_000,
poisoned: false,
})
}
pub fn with_policy(mut self, policy: Policy) -> Self {
self.policy = Some(policy);
self
}
pub fn with_audit(mut self, audit: AuditLog) -> Self {
self.audit = Some(audit);
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_max_output(mut self, max: usize) -> Self {
self.max_output = max;
self
}
pub fn is_poisoned(&self) -> bool {
self.poisoned
}
pub fn exec(&mut self, command: &str) -> Result<ExecResult> {
if self.poisoned {
return Err(Error::SessionPoisoned);
}
if let Some(p) = &self.policy {
if let Err(reason) = p.check(command) {
return Err(Error::PolicyDenied(reason));
}
}
let start_m = format!("__EXECKIT_{}__", self.token);
let end_m = format!("__EXECKITEND_{}__", self.token);
let payload = format!(
"__E=$(umask 077; mktemp 2>/dev/null||{{ f=/tmp/execkitE_{tok}; : >\"$f\"; echo \"$f\"; }}); \
{{ {cmd} ; }} 2>\"$__E\"; \
printf '\\n{start}\\037%d\\037%s\\037' \"$?\" \"$PWD\"; cat \"$__E\" 2>/dev/null; \
printf '{end}\\n'; rm -f \"$__E\"\n",
tok = self.token,
cmd = command,
start = start_m,
end = end_m,
);
let started = Instant::now();
self.io.write_all(payload.as_bytes())?;
let start_b = start_m.as_bytes();
let end_b = end_m.as_bytes();
let max_acc = self.max_output.saturating_mul(2).max(65_536);
let mut acc: Vec<u8> = Vec::new();
let mut overflowed = false;
let deadline = Instant::now() + self.timeout;
loop {
let now = Instant::now();
if now >= deadline {
self.poisoned = true;
return Err(Error::StillRunning);
}
let chunk = match self.io.recv_timeout(deadline - now) {
Some(c) => c,
None => {
self.poisoned = true;
return Err(Error::StillRunning);
}
};
acc.extend_from_slice(&chunk);
if acc.len() > max_acc {
let keep = max_acc / 2;
let tail_start = acc.len() - keep;
let mut compacted = Vec::with_capacity(keep * 2);
compacted.extend_from_slice(&acc[..keep]);
compacted.extend_from_slice(&acc[tail_start..]);
acc = compacted;
overflowed = true;
}
let Some(end_pos) = find(&acc, end_b) else {
continue;
};
let Some(start_pos) = find(&acc[..end_pos], start_b) else {
continue;
};
let between = &acc[start_pos + start_b.len()..end_pos];
let seps: Vec<usize> = between
.iter()
.enumerate()
.filter(|(_, b)| **b == US)
.map(|(i, _)| i)
.collect();
if seps.len() < 3 {
continue;
}
let exit_code: i32 = String::from_utf8_lossy(&between[seps[0] + 1..seps[1]])
.trim()
.parse()
.unwrap_or(-1);
let cwd = String::from_utf8_lossy(&between[seps[1] + 1..seps[2]]).into_owned();
let raw_err = clean(&String::from_utf8_lossy(&between[seps[2] + 1..]));
let raw_out = clean(&String::from_utf8_lossy(&acc[..start_pos]));
let (stdout, t1) = bound(&redact(&raw_out), self.max_output);
let (stderr, t2) = bound(&redact(&raw_err), self.max_output);
let result = ExecResult {
command: command.to_string(),
stdout,
stderr,
exit_code,
duration_ms: started.elapsed().as_millis() as u64,
cwd,
truncated: t1 || t2 || overflowed,
};
if let Some(a) = &self.audit {
if let Err(e) = a.record(&result) {
eprintln!("execkit: audit write failed: {e}");
}
}
return Ok(result);
}
}
}
/// Returns the index of the first occurrence of `needle` in `hay`.
///
/// Yields `None` when `needle` is empty, longer than `hay`, or simply absent.
fn find(hay: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return None;
    }
    // `checked_sub` is `None` exactly when the needle cannot fit in the haystack.
    let last_start = hay.len().checked_sub(needle.len())?;
    (0..=last_start).find(|&offset| hay[offset..].starts_with(needle))
}
/// Builds a hexadecimal token that is unique per call.
///
/// The token concatenates, in lowercase hex: the wall-clock nanoseconds since
/// the Unix epoch (0 if the clock reads before the epoch), a process-wide call
/// counter, and eight bytes read from `/dev/urandom` (left as zeros when the
/// device cannot be opened or read).
fn unique_token() -> String {
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);

    let mut entropy = [0u8; 8];
    if let Ok(mut dev) = std::fs::File::open("/dev/urandom") {
        // Best-effort: on failure the timestamp and counter still vary.
        let _ = dev.read_exact(&mut entropy);
    }
    // Eight bytes printed in order as two hex digits each are the same text as
    // the big-endian u64 zero-padded to sixteen digits.
    let random = u64::from_be_bytes(entropy);

    format!("{nanos:x}{seq:x}{random:016x}")
}