use std::fmt;
use std::path::{Path, PathBuf};
use serde::Serialize;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tracing::debug;
/// Errors produced while talking to a Tor control port.
#[derive(Debug)]
pub enum TorControlError {
/// The TCP or Unix-socket connection to the control port could not be opened.
ConnectionFailed(String),
/// Authentication problem: an unrecognised auth configuration string, an
/// unreadable or wrongly sized cookie file, or a non-250 reply to `AUTHENTICATE`.
AuthFailed(String),
/// The peer sent a malformed or unexpected reply, or closed the connection
/// while a reply was being read.
ProtocolError(String),
/// Underlying I/O failure while reading from or writing to the control connection.
Io(std::io::Error),
}
/// Renders each variant as a fixed, variant-specific prefix followed by `: ` and the
/// wrapped detail (the message string, or the inner I/O error).
impl fmt::Display for TorControlError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the prefix and the detail first so the output shape lives in one place.
        let (prefix, detail): (&str, &dyn fmt::Display) = match self {
            Self::ConnectionFailed(reason) => (
                "control port connection failed",
                reason as &dyn fmt::Display,
            ),
            Self::AuthFailed(reason) => {
                ("control port auth failed", reason as &dyn fmt::Display)
            }
            Self::ProtocolError(reason) => {
                ("control protocol error", reason as &dyn fmt::Display)
            }
            Self::Io(source) => ("control port I/O error", source as &dyn fmt::Display),
        };
        write!(f, "{}: {}", prefix, detail)
    }
}
// Marker impl relying on the trait's default methods: `source()` is not overridden,
// so a wrapped I/O error is only visible through the `Display` output.
impl std::error::Error for TorControlError {}
impl From<std::io::Error> for TorControlError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
}
}
/// How to authenticate against the Tor control port.
///
/// NOTE(review): `Debug` is derived, so formatting a `Password` value with `{:?}`
/// prints the secret in clear text.
#[derive(Debug, Clone)]
pub enum ControlAuth {
/// Authenticate with the contents of the cookie file at this path.
Cookie(PathBuf),
/// Authenticate with this password.
Password(String),
}
impl ControlAuth {
    /// Parses a `control_auth` configuration string.
    ///
    /// Accepted forms:
    /// * `cookie` — cookie authentication using `default_cookie_path`;
    /// * `cookie:<path>` — cookie authentication using `<path>`;
    /// * `password:<secret>` — password authentication using `<secret>`.
    ///
    /// # Errors
    ///
    /// Returns [`TorControlError::AuthFailed`] when `auth_str` matches none of the forms.
    pub fn from_config(auth_str: &str, default_cookie_path: &str) -> Result<Self, TorControlError> {
        // Bare keyword: fall back to the caller-provided default location.
        if auth_str == "cookie" {
            return Ok(Self::Cookie(Path::new(default_cookie_path).to_path_buf()));
        }

        // Explicit cookie path after the `cookie:` prefix.
        if let Some(cookie_path) = auth_str.strip_prefix("cookie:") {
            return Ok(Self::Cookie(Path::new(cookie_path).to_path_buf()));
        }

        match auth_str.strip_prefix("password:") {
            Some(secret) => Ok(Self::Password(secret.to_owned())),
            None => Err(TorControlError::AuthFailed(format!(
                "unknown control_auth format '{}': expected 'cookie', 'cookie:/path', or 'password:secret'",
                auth_str
            ))),
        }
    }
}
/// Serializable snapshot of Tor daemon status, as assembled by
/// `TorControlClient::monitoring_snapshot`.
#[derive(Debug, Clone, Serialize)]
pub struct TorMonitoringInfo {
/// Number following `PROGRESS=` in the `status/bootstrap-phase` reply.
pub bootstrap: u8,
/// `true` when `status/circuit-established` is reported as `1`.
pub circuit_established: bool,
/// Value of `traffic/read` as reported by the control port.
pub traffic_read: u64,
/// Value of `traffic/written` as reported by the control port.
pub traffic_written: u64,
/// Raw `network-liveness` value (the snapshot uses "unknown" when the query fails).
pub network_liveness: String,
/// Raw `version` value (the snapshot uses "unknown" when the query fails).
pub version: String,
/// `true` when `dormant` is reported as `1`.
pub dormant: bool,
}
/// Client for a Tor control-port connection.
///
/// The two halves of the connection are stored as boxed trait objects so the same
/// type can hold either a TCP or a Unix-socket stream.
pub struct TorControlClient {
// Buffered read half; replies are consumed line by line.
reader: BufReader<Box<dyn AsyncRead + Unpin + Send>>,
// Write half; commands are written and flushed one at a time.
writer: Box<dyn AsyncWrite + Unpin + Send>,
}
/// Opaque `Debug` output: only the type name is shown, no fields.
impl fmt::Debug for TorControlClient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("TorControlClient");
        builder.finish_non_exhaustive()
    }
}
impl TorControlClient {
/// Connects to the Tor control port at `addr`.
///
/// On unix targets, an `addr` that `is_unix_socket_path` recognises as a filesystem
/// path is opened as a Unix domain socket; everything else is treated as a TCP address.
///
/// # Errors
///
/// Returns [`TorControlError::ConnectionFailed`] when the connection cannot be opened.
pub async fn connect(addr: &str) -> Result<Self, TorControlError> {
    #[cfg(unix)]
    {
        if is_unix_socket_path(addr) {
            return Self::connect_unix(addr).await;
        }
    }

    Self::connect_tcp(addr).await
}
async fn connect_tcp(addr: &str) -> Result<Self, TorControlError> {
let stream = TcpStream::connect(addr).await.map_err(|e| {
TorControlError::ConnectionFailed(format!(
"failed to connect to control port {}: {}",
addr, e
))
})?;
let (read_half, write_half) = stream.into_split();
debug!(addr = %addr, transport = "tcp", "Connected to Tor control port");
Ok(Self {
reader: BufReader::new(Box::new(read_half)),
writer: Box::new(write_half),
})
}
#[cfg(unix)]
async fn connect_unix(path: &str) -> Result<Self, TorControlError> {
let stream = UnixStream::connect(path).await.map_err(|e| {
TorControlError::ConnectionFailed(format!(
"failed to connect to control socket {}: {}",
path, e
))
})?;
let (read_half, write_half) = stream.into_split();
debug!(path = %path, transport = "unix", "Connected to Tor control port");
Ok(Self {
reader: BufReader::new(Box::new(read_half)),
writer: Box::new(write_half),
})
}
/// Sends an `AUTHENTICATE` command built from `auth` and checks the reply.
///
/// * Cookie auth sends the cookie file contents hex-encoded.
/// * Password auth sends the password as a double-quoted string with `\` and `"`
///   backslash-escaped.
///
/// # Errors
///
/// Returns [`TorControlError::AuthFailed`] when the cookie file cannot be loaded or
/// the reply code is not 250; I/O and protocol errors from the exchange are propagated.
pub async fn authenticate(&mut self, auth: &ControlAuth) -> Result<(), TorControlError> {
    let command = match auth {
        ControlAuth::Cookie(path) => {
            format!("AUTHENTICATE {}\r\n", hex::encode(read_cookie_file(path)?))
        }
        ControlAuth::Password(password) => {
            // Escape in a single pass: prefix every backslash and quote with `\`.
            let mut quoted = String::with_capacity(password.len() + 2);
            for ch in password.chars() {
                if matches!(ch, '\\' | '"') {
                    quoted.push('\\');
                }
                quoted.push(ch);
            }
            format!("AUTHENTICATE \"{}\"\r\n", quoted)
        }
    };

    self.send_command(&command).await?;
    let reply = self.read_response().await?;

    if reply.code == 250 {
        debug!("Authenticated with Tor control port");
        Ok(())
    } else {
        Err(TorControlError::AuthFailed(format!(
            "AUTHENTICATE failed: {} {}",
            reply.code, reply.message
        )))
    }
}
/// Issues `GETINFO <key>` and returns the text after `<key>=` from the first data
/// line of the reply that carries that prefix.
///
/// # Errors
///
/// Returns [`TorControlError::ProtocolError`] when the reply code is not 250 or no
/// data line starts with `<key>=`; I/O and parse errors from the exchange are propagated.
async fn getinfo(&mut self, key: &str) -> Result<String, TorControlError> {
    self.send_command(&format!("GETINFO {}\r\n", key)).await?;
    let reply = self.read_response().await?;

    if reply.code != 250 {
        return Err(TorControlError::ProtocolError(format!(
            "GETINFO {} failed: {} {}",
            key, reply.code, reply.message
        )));
    }

    let wanted = format!("{}=", key);
    let found = reply
        .data_lines
        .iter()
        .find_map(|entry| entry.strip_prefix(wanted.as_str()));

    match found {
        Some(value) => Ok(value.to_string()),
        None => Err(TorControlError::ProtocolError(format!(
            "GETINFO response missing key '{}'",
            key
        ))),
    }
}
/// Returns the bootstrap progress: the decimal number that directly follows the
/// first `PROGRESS=` in the `status/bootstrap-phase` value.
///
/// # Errors
///
/// Returns [`TorControlError::ProtocolError`] when `PROGRESS=` is absent or the
/// digits after it do not parse as a `u8`; `getinfo` errors are propagated.
pub async fn get_bootstrap_phase(&mut self) -> Result<u8, TorControlError> {
    const MARKER: &str = "PROGRESS=";

    let status = self.getinfo("status/bootstrap-phase").await?;
    let progress = status.find(MARKER).and_then(|at| {
        let tail = &status[at + MARKER.len()..];
        // Keep only the leading run of ASCII digits.
        let digits_end = tail
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(tail.len());
        tail[..digits_end].parse::<u8>().ok()
    });

    progress.ok_or_else(|| {
        TorControlError::ProtocolError("could not parse bootstrap progress".into())
    })
}
/// Returns `true` when `status/circuit-established`, trimmed, equals `1`.
///
/// # Errors
///
/// Propagates any error from the underlying `GETINFO` query.
pub async fn is_circuit_established(&mut self) -> Result<bool, TorControlError> {
    let flag = self.getinfo("status/circuit-established").await?;
    Ok(matches!(flag.trim(), "1"))
}
/// Returns the `traffic/read` counter reported by the control port.
///
/// # Errors
///
/// Returns [`TorControlError::ProtocolError`] when the value is not a valid `u64`;
/// `getinfo` errors are propagated.
pub async fn traffic_read(&mut self) -> Result<u64, TorControlError> {
    let raw = self.getinfo("traffic/read").await?;
    match raw.trim().parse::<u64>() {
        Ok(count) => Ok(count),
        Err(_) => Err(TorControlError::ProtocolError(format!(
            "invalid traffic/read value: '{}'",
            raw
        ))),
    }
}
/// Returns the `traffic/written` counter reported by the control port.
///
/// # Errors
///
/// Returns [`TorControlError::ProtocolError`] when the value is not a valid `u64`;
/// `getinfo` errors are propagated.
pub async fn traffic_written(&mut self) -> Result<u64, TorControlError> {
    let raw = self.getinfo("traffic/written").await?;
    match raw.trim().parse::<u64>() {
        Ok(count) => Ok(count),
        Err(_) => Err(TorControlError::ProtocolError(format!(
            "invalid traffic/written value: '{}'",
            raw
        ))),
    }
}
/// Returns the raw `network-liveness` value reported by the control port.
///
/// # Errors
///
/// Propagates any error from the underlying `GETINFO` query.
pub async fn network_liveness(&mut self) -> Result<String, TorControlError> {
    const KEY: &str = "network-liveness";
    self.getinfo(KEY).await
}
/// Returns the raw `version` value reported by the control port.
///
/// # Errors
///
/// Propagates any error from the underlying `GETINFO` query.
pub async fn version(&mut self) -> Result<String, TorControlError> {
    const KEY: &str = "version";
    self.getinfo(KEY).await
}
/// Returns `true` when the `dormant` value, trimmed, equals `1`.
///
/// # Errors
///
/// Propagates any error from the underlying `GETINFO` query.
pub async fn is_dormant(&mut self) -> Result<bool, TorControlError> {
    let flag = self.getinfo("dormant").await?;
    Ok(matches!(flag.trim(), "1"))
}
/// Returns the entries of `net/listeners/socks`: the value is split on whitespace
/// and surrounding double quotes are removed from each entry.
///
/// # Errors
///
/// Propagates any error from the underlying `GETINFO` query.
pub async fn socks_listeners(&mut self) -> Result<Vec<String>, TorControlError> {
    let raw = self.getinfo("net/listeners/socks").await?;

    let mut listeners = Vec::new();
    for entry in raw.split_whitespace() {
        listeners.push(entry.trim_matches('"').to_string());
    }
    Ok(listeners)
}
pub async fn monitoring_snapshot(&mut self) -> Result<TorMonitoringInfo, TorControlError> {
let bootstrap = self.get_bootstrap_phase().await.unwrap_or(0);
let circuit_established = self.is_circuit_established().await.unwrap_or(false);
let traffic_read = self.traffic_read().await.unwrap_or(0);
let traffic_written = self.traffic_written().await.unwrap_or(0);
let network_liveness = self
.network_liveness()
.await
.unwrap_or_else(|_| "unknown".into());
let version = self.version().await.unwrap_or_else(|_| "unknown".into());
let dormant = self.is_dormant().await.unwrap_or(false);
Ok(TorMonitoringInfo {
bootstrap,
circuit_established,
traffic_read,
traffic_written,
network_liveness,
version,
dormant,
})
}
async fn send_command(&mut self, command: &str) -> Result<(), TorControlError> {
self.writer.write_all(command.as_bytes()).await?;
self.writer.flush().await?;
Ok(())
}
/// Reads one complete reply from the control connection.
///
/// Every reply line has the shape `<3-digit code><separator><text>`:
/// * `-` — intermediate line; `<text>` is appended to `data_lines`;
/// * `+` — start of a data block; `<text>` is appended to `data_lines`, then every
///   following line up to a lone `.` is appended too (one leading `.` is stripped
///   from each of those lines);
/// * ` ` — final line; its code and text become the reply's `code` and `message`.
///
/// The three-character status prefix must consist of ASCII digits, and the payload
/// is only sliced once the separator is known to be one of the ASCII characters
/// above. This keeps all byte-offset slicing on UTF-8 character boundaries, so a
/// malformed line from the peer yields an error instead of a panic.
///
/// # Errors
///
/// Returns [`TorControlError::ProtocolError`] when the connection closes mid-reply,
/// or a line is too short, has a non-numeric code, or has an unknown separator.
/// Returns [`TorControlError::Io`] on read failures.
async fn read_response(&mut self) -> Result<ControlResponse, TorControlError> {
    let mut data_lines = Vec::new();
    let mut line_buf = String::new();

    loop {
        line_buf.clear();
        if self.reader.read_line(&mut line_buf).await? == 0 {
            return Err(TorControlError::ProtocolError(
                "control port connection closed".into(),
            ));
        }

        let line = line_buf.trim_end_matches(['\r', '\n']);
        let bytes = line.as_bytes();
        if bytes.len() < 4 {
            return Err(TorControlError::ProtocolError(format!(
                "response line too short: '{}'",
                line
            )));
        }

        // Three ASCII digits guarantee that byte offset 3 is a char boundary and
        // reject sign-prefixed forms such as "+25" that integer parsing would accept.
        if !bytes[..3].iter().all(u8::is_ascii_digit) {
            return Err(TorControlError::ProtocolError(format!(
                "invalid response code in: '{}'",
                line
            )));
        }
        let code: u16 = line[..3].parse().map_err(|_| {
            TorControlError::ProtocolError(format!("invalid response code in: '{}'", line))
        })?;

        // `line[4..]` is only evaluated in arms where byte 3 is a known ASCII
        // separator, so offset 4 is always a char boundary there.
        match bytes[3] {
            b'-' => data_lines.push(line[4..].to_string()),
            b' ' => {
                return Ok(ControlResponse {
                    code,
                    message: line[4..].to_string(),
                    data_lines,
                });
            }
            b'+' => {
                data_lines.push(line[4..].to_string());
                loop {
                    line_buf.clear();
                    if self.reader.read_line(&mut line_buf).await? == 0 {
                        return Err(TorControlError::ProtocolError(
                            "connection closed during multi-line response".into(),
                        ));
                    }
                    let dot_line = line_buf.trim_end_matches(['\r', '\n']);
                    if dot_line == "." {
                        break;
                    }
                    // Undo dot-stuffing: a data line beginning with '.' carries an extra one.
                    let unescaped = dot_line.strip_prefix('.').unwrap_or(dot_line);
                    data_lines.push(unescaped.to_string());
                }
            }
            other => {
                return Err(TorControlError::ProtocolError(format!(
                    "unexpected separator '{}' in: '{}'",
                    other as char, line
                )));
            }
        }
    }
}
}
/// One parsed control-port reply, as produced by `TorControlClient::read_response`.
struct ControlResponse {
// Status code taken from the final reply line.
code: u16,
// Text following the status code on the final reply line.
message: String,
// Payload of the intermediate (`-`) and data-block (`+`) lines, in arrival order.
data_lines: Vec<String>,
}
/// Loads the authentication cookie stored at `path`.
///
/// # Errors
///
/// Returns [`TorControlError::AuthFailed`] when the file cannot be read or its
/// length is not exactly 32 bytes.
fn read_cookie_file(path: &Path) -> Result<Vec<u8>, TorControlError> {
    let cookie = match std::fs::read(path) {
        Ok(bytes) => bytes,
        Err(cause) => {
            return Err(TorControlError::AuthFailed(format!(
                "failed to read cookie file '{}': {}",
                path.display(),
                cause
            )));
        }
    };

    match cookie.len() {
        32 => Ok(cookie),
        actual => Err(TorControlError::AuthFailed(format!(
            "cookie file '{}' has {} bytes, expected 32",
            path.display(),
            actual
        ))),
    }
}
/// Returns `true` when `addr` looks like a filesystem path rather than a `host:port`
/// address: it is absolute (`/…`) or explicitly relative (`./…` or `../…`).
///
/// A valid `host:port` address cannot start with any of these prefixes, so the
/// check never misclassifies a TCP address.
#[cfg(unix)]
fn is_unix_socket_path(addr: &str) -> bool {
    addr.starts_with('/') || addr.starts_with("./") || addr.starts_with("../")
}
/// Minimal lowercase hexadecimal encoder (keeps the file free of an external crate).
mod hex {
    /// Lookup table mapping a nibble (0..=15) to its lowercase hex digit.
    const HEX_CHARS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `data` as a lowercase hex string, two characters per input byte.
    ///
    /// Accepts anything viewable as a byte slice (`Vec<u8>`, `&[u8]`, arrays, `&str`),
    /// so callers are not forced to give up ownership of a buffer just to encode it.
    pub fn encode(data: impl AsRef<[u8]>) -> String {
        let bytes = data.as_ref();
        let mut encoded = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            // High nibble first, then low nibble.
            encoded.push(char::from(HEX_CHARS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_CHARS[usize::from(byte & 0x0f)]));
        }
        encoded
    }
}
// Unit tests for the control-port client. Network-facing tests run against either an
// ad-hoc local listener or the mock control server from `mock_control`.
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::tor::mock_control::{self, MockTorControlServer};
use tempfile::TempDir;
// `cookie` alone selects cookie auth with the supplied default path.
#[test]
fn test_control_auth_cookie_default() {
let auth = ControlAuth::from_config("cookie", "/var/run/tor/cookie").unwrap();
match auth {
ControlAuth::Cookie(path) => assert_eq!(path, Path::new("/var/run/tor/cookie")),
_ => panic!("expected Cookie"),
}
}
// `cookie:<path>` overrides the default cookie path.
#[test]
fn test_control_auth_cookie_custom_path() {
let auth = ControlAuth::from_config("cookie:/tmp/my_cookie", "/default").unwrap();
match auth {
ControlAuth::Cookie(path) => assert_eq!(path, Path::new("/tmp/my_cookie")),
_ => panic!("expected Cookie"),
}
}
// `password:<secret>` selects password auth with the text after the prefix.
#[test]
fn test_control_auth_password() {
let auth = ControlAuth::from_config("password:mypass", "/default").unwrap();
match auth {
ControlAuth::Password(p) => assert_eq!(p, "mypass"),
_ => panic!("expected Password"),
}
}
// Any other string is rejected.
#[test]
fn test_control_auth_invalid() {
let result = ControlAuth::from_config("unknown", "/default");
assert!(result.is_err());
}
// Hex output is lowercase and zero-padded to two digits per byte.
#[test]
fn test_hex_encode() {
assert_eq!(hex::encode(vec![0xde, 0xad, 0xbe, 0xef]), "deadbeef");
assert_eq!(hex::encode(vec![0x00, 0xff]), "00ff");
}
// Absolute and `./`-relative paths count as sockets; `host:port` strings do not.
#[cfg(unix)]
#[test]
fn test_is_unix_socket_path() {
assert!(is_unix_socket_path("/run/tor/control"));
assert!(is_unix_socket_path("/var/run/tor/control"));
assert!(is_unix_socket_path("./tor-control.sock"));
assert!(!is_unix_socket_path("127.0.0.1:9051"));
assert!(!is_unix_socket_path("tor-daemon:9051"));
assert!(!is_unix_socket_path("localhost:9051"));
}
// A path-like address goes through the Unix-socket branch, as shown by the
// "control socket" wording in the resulting error.
#[cfg(unix)]
#[tokio::test]
async fn test_connect_unix_socket_nonexistent() {
let result = TorControlClient::connect("/tmp/nonexistent-tor-control.sock").await;
assert!(result.is_err());
let err = format!("{}", result.unwrap_err());
assert!(err.contains("control socket"));
}
// End-to-end over a real Unix socket: a hand-rolled server accepts one connection,
// checks that the first line is an AUTHENTICATE command, and answers `250 OK`.
#[cfg(unix)]
#[tokio::test]
async fn test_connect_unix_socket_roundtrip() {
let dir = TempDir::new().unwrap();
let sock_path = dir.path().join("control.sock");
let sock_path_str = sock_path.to_str().unwrap().to_string();
let listener = tokio::net::UnixListener::bind(&sock_path).unwrap();
let handle = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let (reader, mut writer) = stream.into_split();
let mut reader = tokio::io::BufReader::new(reader);
let mut line = String::new();
reader.read_line(&mut line).await.unwrap();
assert!(line.starts_with("AUTHENTICATE"));
use tokio::io::AsyncWriteExt;
writer.write_all(b"250 OK\r\n").await.unwrap();
writer.flush().await.unwrap();
});
let mut client = TorControlClient::connect(&sock_path_str).await.unwrap();
let auth = ControlAuth::Password("test".to_string());
client.authenticate(&auth).await.unwrap();
// Surface any assertion failure raised inside the server task.
handle.await.unwrap();
}
// A 32-byte cookie file is returned unchanged.
#[test]
fn test_read_cookie_file_valid() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("cookie");
let cookie_data = vec![0xAA; 32];
std::fs::write(&path, &cookie_data).unwrap();
let loaded = read_cookie_file(&path).unwrap();
assert_eq!(loaded, cookie_data);
}
// A cookie file that is not 32 bytes long is rejected.
#[test]
fn test_read_cookie_file_wrong_size() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("cookie");
std::fs::write(&path, [0u8; 16]).unwrap();
assert!(read_cookie_file(&path).is_err());
}
// A missing cookie file is reported as an error.
#[test]
fn test_read_cookie_file_nonexistent() {
assert!(read_cookie_file(Path::new("/nonexistent/cookie")).is_err());
}
// Password authentication succeeds against the default mock server.
#[tokio::test]
async fn test_authenticate_password() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
let auth = ControlAuth::Password("testpass".to_string());
client.authenticate(&auth).await.unwrap();
}
// Cookie authentication succeeds against the default mock server.
#[tokio::test]
async fn test_authenticate_cookie() {
let mock = MockTorControlServer::start().await;
let dir = TempDir::new().unwrap();
let cookie_path = dir.path().join("cookie");
std::fs::write(&cookie_path, [0xAA; 32]).unwrap();
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
let auth = ControlAuth::Cookie(cookie_path);
client.authenticate(&auth).await.unwrap();
}
// The bootstrap progress obtained from the mock is expected to be 100.
#[tokio::test]
async fn test_get_bootstrap_phase() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
let auth = ControlAuth::Password("testpass".to_string());
client.authenticate(&auth).await.unwrap();
let progress = client.get_bootstrap_phase().await.unwrap();
assert_eq!(progress, 100);
}
// With `reject_auth` set on the mock, authentication must fail.
#[tokio::test]
async fn test_auth_failure() {
let mock = MockTorControlServer::start_with_options(mock_control::MockOptions {
reject_auth: true,
})
.await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
let auth = ControlAuth::Password("wrongpass".to_string());
let result = client.authenticate(&auth).await;
assert!(result.is_err());
}
// Bind to an ephemeral port, release it, then check that connecting to it fails.
#[tokio::test]
async fn test_connect_to_closed_port() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
drop(listener);
let result = TorControlClient::connect(&addr.to_string()).await;
assert!(result.is_err());
}
// The mock is expected to report an established circuit.
#[tokio::test]
async fn test_is_circuit_established() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
client
.authenticate(&ControlAuth::Password("test".into()))
.await
.unwrap();
assert!(client.is_circuit_established().await.unwrap());
}
// The mock's traffic counters are expected to be 1048576 read / 524288 written.
#[tokio::test]
async fn test_traffic_counters() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
client
.authenticate(&ControlAuth::Password("test".into()))
.await
.unwrap();
assert_eq!(client.traffic_read().await.unwrap(), 1048576);
assert_eq!(client.traffic_written().await.unwrap(), 524288);
}
// The mock's network liveness is expected to be "up".
#[tokio::test]
async fn test_network_liveness() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
client
.authenticate(&ControlAuth::Password("test".into()))
.await
.unwrap();
assert_eq!(client.network_liveness().await.unwrap(), "up");
}
// The mock's version string is expected to be "0.4.8.10".
#[tokio::test]
async fn test_version() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
client
.authenticate(&ControlAuth::Password("test".into()))
.await
.unwrap();
assert_eq!(client.version().await.unwrap(), "0.4.8.10");
}
// The mock is expected to report a non-dormant daemon.
#[tokio::test]
async fn test_dormant() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
client
.authenticate(&ControlAuth::Password("test".into()))
.await
.unwrap();
assert!(!client.is_dormant().await.unwrap());
}
// The SOCKS listener list is expected to contain one unquoted address.
#[tokio::test]
async fn test_socks_listeners() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
client
.authenticate(&ControlAuth::Password("test".into()))
.await
.unwrap();
let listeners = client.socks_listeners().await.unwrap();
assert_eq!(listeners, vec!["127.0.0.1:9050"]);
}
// The aggregated snapshot must agree with the values checked by the individual tests.
#[tokio::test]
async fn test_monitoring_snapshot() {
let mock = MockTorControlServer::start().await;
let mut client = TorControlClient::connect(&mock.addr().to_string())
.await
.unwrap();
client
.authenticate(&ControlAuth::Password("test".into()))
.await
.unwrap();
let info = client.monitoring_snapshot().await.unwrap();
assert_eq!(info.bootstrap, 100);
assert!(info.circuit_established);
assert_eq!(info.traffic_read, 1048576);
assert_eq!(info.traffic_written, 524288);
assert_eq!(info.network_liveness, "up");
assert_eq!(info.version, "0.4.8.10");
assert!(!info.dormant);
}
}