use std::{
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
fmt::{self, Debug},
fs::{DirBuilder, File, OpenOptions, Permissions},
io::{Error as IoError, Write},
ops::{Deref, DerefMut},
os::{
fd::{AsRawFd, FromRawFd},
unix::fs::{DirBuilderExt, OpenOptionsExt, PermissionsExt},
},
path::Path,
time::{Duration, Instant},
};
use libc::pid_t;
use mio::{
Events, Interest, Poll, Token,
net::{UnixListener, UnixStream},
};
use nix::{
sys::signal::{Signal, kill},
unistd::Pid,
};
use sozu_command_lib::{
channel::Channel,
config::Config,
proto::command::{
Event, EventKind, Request, ResponseContent, ResponseStatus, RunState, Status,
WorkerRequest, WorkerResponse, request::RequestType, response_content::ContentType,
},
ready::Ready,
scm_socket::{Listeners, ScmSocket, ScmSocketError},
state::ConfigState,
};
use sozu_lib::metrics::names;
use super::upgrade::SerializedWorkerSession;
use crate::{
command::{
sessions::{
ClientResult, ClientSession, OptionalClient, WorkerResult, WorkerSession, wants_to_tick,
},
upgrade::UpgradeData,
},
util::{UtilError, disable_close_on_exec, enable_close_on_exec, get_executable_path},
worker::{WorkerError, fork_main_into_worker},
};
/// Identifier given to a client session, allocated by `Server::next_client_id`.
pub type ClientId = u32;
/// Value wrapped in the MIO `Token` of a client or worker session.
pub type SessionId = usize;
/// Key of a gathering task in the `tasks` / `queued_tasks` maps.
pub type TaskId = usize;
/// Identifier of a worker, allocated by `Server::next_worker_id`.
pub type WorkerId = u32;
/// Identifier of a request sent to a worker; key of `Server::in_flight`.
pub type RequestId = String;
/// Accumulates the worker responses that belong to one task.
#[allow(unused)]
pub trait Gatherer {
/// Raises the number of worker responses to wait for by `count`.
fn inc_expected_responses(&mut self, count: usize);
/// Tells whether the task can be finished; `CommandHub::run` checks this on
/// every turn of its loop.
fn has_finished(&self) -> bool;
/// Handles one `message` received from worker `worker_id` for this task.
///
/// `client` is looked up from the task's client token and may be empty.
fn on_message(
&mut self,
server: &mut Server,
client: &mut OptionalClient,
worker_id: WorkerId,
message: WorkerResponse,
);
}
/// A unit of work whose completion depends on responses gathered from workers.
#[allow(unused)]
pub trait GatheringTask: Debug {
/// Token of the client session this task reports to, if any.
fn client_token(&self) -> Option<Token>;
/// Gives access to the gatherer that counts and stores the worker responses.
fn get_gatherer(&mut self) -> &mut dyn Gatherer;
/// Consumes the task once the hub removes it from its task map.
///
/// `timed_out` is meant to tell whether the task was ended by its deadline
/// rather than by its gatherer reporting completion.
fn on_finish(
self: Box<Self>,
server: &mut Server,
client: &mut OptionalClient,
timed_out: bool,
);
}
/// Ways of answering a client. Implementations live outside this file.
///
/// NOTE(review): the exact wire semantics of each method are defined by the
/// implementors; the summaries below are inferred from names and call sites.
pub trait MessageClient {
/// Presumably sends a final, successful answer carrying `message`.
fn finish_ok<T: Into<String>>(&mut self, message: T);
/// Presumably like `finish_ok`, with `content` attached to the answer.
fn finish_ok_with_content<T: Into<String>>(&mut self, content: ResponseContent, message: T);
/// Presumably sends a final, failed answer carrying `message`.
fn finish_failure<T: Into<String>>(&mut self, message: T);
/// Sends an intermediate answer; used in this file to relay a worker's
/// "processing" status.
fn return_processing<T: Into<String>>(&mut self, message: T);
/// Sends an intermediate answer with `content`; used in this file to push
/// events to subscribed clients.
fn return_processing_with_content<S: Into<String>>(
&mut self,
message: S,
content: ResponseContent,
);
}
/// Deadline policy of a task, resolved by `Server::new_task`.
pub enum Timeout {
/// The task has no deadline.
None,
/// The deadline is `config.worker_timeout` seconds after task creation.
Default,
/// The deadline is the given duration after task creation.
#[allow(unused)]
Custom(Duration),
}
/// A task together with its optional deadline.
#[derive(Debug)]
struct TaskContainer {
// The task itself, consumed by `on_finish` when it ends.
job: Box<dyn GatheringTask>,
// Point in time after which the task is considered timed out, if any.
timeout: Option<Instant>,
}
/// Gatherer that counts OK / failure responses and keeps every response.
#[derive(Debug, Default)]
pub struct DefaultGatherer {
/// Number of responses received with status `Ok`.
pub ok: usize,
/// Number of responses received with status `Failure`.
pub errors: usize,
/// Every response received, with the id of the worker that sent it.
pub responses: Vec<(WorkerId, WorkerResponse)>,
/// Number of final responses to wait for.
pub expected_responses: usize,
}
#[allow(unused)]
impl Gatherer for DefaultGatherer {
    /// Adds `count` to the number of final responses to wait for.
    fn inc_expected_responses(&mut self, count: usize) {
        self.expected_responses += count;
    }

    /// The task is done once OK and failure responses together reach the expected count.
    fn has_finished(&self) -> bool {
        let final_responses = self.ok + self.errors;
        final_responses >= self.expected_responses
    }

    /// Counts the response by status and stores it.
    ///
    /// A `Processing` status is not counted; it is relayed to the client instead.
    /// A status that cannot be decoded is logged. The message is stored in every case.
    fn on_message(
        &mut self,
        server: &mut Server,
        client: &mut OptionalClient,
        worker_id: WorkerId,
        message: WorkerResponse,
    ) {
        match ResponseStatus::try_from(message.status) {
            Err(e) => warn!("error decoding response status: {}", e),
            Ok(status) => match status {
                ResponseStatus::Ok => self.ok += 1,
                ResponseStatus::Failure => self.errors += 1,
                ResponseStatus::Processing => {
                    let notice = format!(
                        "Worker {} is processing {}. {}",
                        worker_id, message.id, message.message
                    );
                    client.return_processing(notice)
                }
            },
        }
        self.responses.push((worker_id, message));
    }
}
/// Errors that prevent a `CommandHub` from being built.
#[derive(thiserror::Error, Debug)]
pub enum HubError {
/// `Server::new` failed.
#[error("could not create main server: {0}")]
CreateServer(ServerError),
/// The path of the running executable could not be determined.
#[error("could not get executable path")]
GetExecutablePath(UtilError),
/// The SCM socket of the worker with the given id could not be rebuilt.
#[error("could not create SCM socket for worker {0}: {1}")]
CreateScmSocket(u32, ScmSocketError),
}
/// Owns the `Server` plus the client sessions and running tasks, and drives
/// the event loop.
#[derive(Debug)]
pub struct CommandHub {
/// State shared with request handlers; also reachable through `Deref`.
pub server: Server,
// Connected client sessions, keyed by their MIO token.
clients: HashMap<Token, ClientSession>,
// Tasks currently awaiting worker responses.
tasks: HashMap<TaskId, TaskContainer>,
// Path of the command socket ("unknown" when it cannot be resolved),
// handed to every new client session.
command_socket_path: std::sync::Arc<str>,
}
/// Lets `Server` fields and methods be used directly on a `CommandHub`.
impl Deref for CommandHub {
type Target = Server;
fn deref(&self) -> &Self::Target {
&self.server
}
}
/// Mutable counterpart of the `Deref` implementation: exposes `&mut Server`.
impl DerefMut for CommandHub {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.server
}
}
impl CommandHub {
/// Builds a hub with a fresh [`Server`], no client and no task.
///
/// The command socket path is taken from `config`; when it cannot be resolved
/// the placeholder `"unknown"` is used instead.
///
/// # Errors
///
/// `HubError::CreateServer` when the server cannot be created.
pub fn new(
    unix_listener: UnixListener,
    config: Config,
    executable_path: String,
) -> Result<Self, HubError> {
    // Resolve the path first: `config` is moved into the server right after.
    let socket_path = config
        .command_socket_path()
        .unwrap_or_else(|_| "unknown".to_owned());
    let command_socket_path: std::sync::Arc<str> = socket_path.into();

    let server =
        Server::new(unix_listener, config, executable_path).map_err(HubError::CreateServer)?;

    Ok(Self {
        server,
        clients: HashMap::new(),
        tasks: HashMap::new(),
        command_socket_path,
    })
}
/// Registers a freshly accepted connection as a new client session.
///
/// The stream is registered in the poll under a new token; on failure the error
/// is logged and the stream is dropped. Otherwise the peer credentials are read
/// from the socket, resolved to a command name and a user name when possible,
/// and stored in the session.
fn register_client(&mut self, mut stream: UnixStream) {
    let token = self.next_session_token();
    let registration = self.register(token, &mut stream);
    if let Err(err) = registration {
        error!(
            "Could not register client (token={:?}): {} — dropping connection",
            token, err
        );
        return;
    }

    // Who is on the other side of the socket, as far as the OS can tell.
    let peer_cred = peer_cred_from_stream(&stream);
    let actor_comm = peer_cred.pid.and_then(peer_comm);
    let actor_user = peer_cred.uid.and_then(peer_user);

    let max_buffer_size = self.config.max_command_buffer_size;
    let channel = Channel::new(stream, 4096, max_buffer_size);
    let id = self.next_client_id();
    let socket_path = self.command_socket_path.clone();
    let session = ClientSession::new(
        channel, id, token, peer_cred, actor_comm, actor_user, socket_path,
    );

    info!(
        "Register new client: {} (actor_uid={} actor_pid={} actor_user={} actor_comm={})",
        id,
        session.actor_uid_display(),
        session.actor_pid_display(),
        session.actor_user_display(),
        session.actor_comm_display()
    );
    debug!("registering client {:?}", session);
    self.clients.insert(token, session);
}
fn flush_pending_audit_events(&mut self) {
let events: Vec<Event> = self.server.pending_audit_events.drain(..).collect();
if events.is_empty() {
return;
}
for event in events {
for client_token in &self.server.event_subscribers {
if let Some(client) = self.clients.get_mut(client_token) {
client.return_processing_with_content(
String::from("main"),
ContentType::Event(event.clone()).into(),
);
}
}
}
}
/// Returns the client session registered under `token` together with the server,
/// so that both can be borrowed mutably at the same time.
fn get_client_mut(&mut self, token: &Token) -> Option<(&mut Server, &mut ClientSession)> {
    let client = self.clients.get_mut(token)?;
    Some((&mut self.server, client))
}
/// Rebuilds a hub from the data handed over by a previous main process.
///
/// The listener and the worker channels are recreated from raw file descriptors,
/// the configuration state and id counters are restored, and every worker that is
/// neither `Stopped` nor `Stopping` is registered again. A worker that fails to
/// register is logged and skipped.
///
/// # Errors
///
/// - `HubError::GetExecutablePath` when the executable path cannot be found,
/// - `HubError::CreateServer` when the server cannot be created,
/// - `HubError::CreateScmSocket` when the SCM socket of a worker cannot be rebuilt.
pub fn from_upgrade_data(upgrade_data: UpgradeData) -> Result<Self, HubError> {
    let UpgradeData {
        command_socket_fd,
        config,
        workers,
        state,
        next_client_id,
        next_session_id,
        next_task_id,
        next_worker_id,
        boot_generation,
    } = upgrade_data;

    let executable_path =
        unsafe { get_executable_path() }.map_err(HubError::GetExecutablePath)?;
    // SAFETY: assumes `command_socket_fd` is an open listening socket inherited from
    // the previous main process and owned by nobody else — TODO confirm against the
    // upgrade procedure.
    let unix_listener = unsafe { UnixListener::from_raw_fd(command_socket_fd) };

    // Read what is needed from `config` before it is moved into the server.
    let buffer_size = config.command_buffer_size;
    let max_buffer_size = config.max_command_buffer_size;
    let socket_path = config
        .command_socket_path()
        .unwrap_or_else(|_| "unknown".to_owned());
    let command_socket_path: std::sync::Arc<str> = socket_path.into();

    let mut server =
        Server::new(unix_listener, config, executable_path).map_err(HubError::CreateServer)?;
    server.state = state;
    server.update_counts();
    server.next_client_id = next_client_id;
    server.next_session_id = next_session_id;
    server.next_task_id = next_task_id;
    server.next_worker_id = next_worker_id;
    server.boot_generation = boot_generation;

    for worker in workers.iter() {
        if worker.run_state == RunState::Stopped || worker.run_state == RunState::Stopping {
            continue;
        }
        // SAFETY: same assumption as for the listener, for the worker channel fd.
        let stream = unsafe { UnixStream::from_raw_fd(worker.channel_fd) };
        let channel: Channel<WorkerRequest, WorkerResponse> =
            Channel::new(stream, buffer_size, max_buffer_size);
        let scm_socket = ScmSocket::new(worker.scm_fd)
            .map_err(|scm_err| HubError::CreateScmSocket(worker.id, scm_err))?;

        let registration = server.register_worker(worker.id, worker.pid, channel, scm_socket);
        if let Err(err) = registration {
            error!("could not register worker: {}", err);
        }
    }

    Ok(CommandHub {
        server,
        clients: HashMap::new(),
        tasks: HashMap::new(),
        command_socket_path,
    })
}
/// Runs the main event loop until the hub has stopped.
///
/// Each turn of the loop:
/// 1. merges the queued tasks into the running ones and finishes those that are
///    complete or past their deadline,
/// 2. derives the poll timeout from the *earliest* remaining task deadline (zero
///    when a session wants to tick or a worker must be spawned),
/// 3. polls MIO, respawns missing workers, then dispatches readiness to the
///    listener (`Token(0)`), the client sessions and the worker sessions.
///
/// Returns `self.server.upgrading` once the server is `Stopping` and no client
/// session has data left in its `back_buf`.
pub fn run(&mut self) -> bool {
    let mut events = Events::with_capacity(100);
    debug!("running the command hub: {:?}", self);
    loop {
        let run_state = self.run_state;
        let now = Instant::now();

        // Adopt the tasks queued since the last turn, and end those that are done.
        let mut tasks = std::mem::take(&mut self.tasks);
        let mut queued_tasks = std::mem::take(&mut self.server.queued_tasks);
        self.tasks = tasks
            .drain()
            .chain(queued_tasks.drain())
            .filter_map(|(task_id, mut task)| {
                if task.job.get_gatherer().has_finished() {
                    self.handle_finishing_task(task_id, task, false);
                    return None;
                }
                if let Some(timeout) = task.timeout {
                    if timeout < now {
                        self.handle_finishing_task(task_id, task, true);
                        return None;
                    }
                }
                Some((task_id, task))
            })
            .collect();

        // Wake up for the soonest deadline: waiting for the latest one would delay
        // the timeout of every other task until an unrelated event arrives.
        let next_timeout = self.tasks.values().filter_map(|t| t.timeout).min();
        let mut poll_timeout = next_timeout.map(|t| t.saturating_duration_since(now));

        if self.run_state == ServerState::Stopping {
            // Keep only the clients that still have data to flush.
            self.clients
                .retain(|_, s| s.channel.back_buf.available_data() > 0);
            if self.clients.is_empty() {
                return self.server.upgrading;
            }
        }

        // Sessions that want to make progress without waiting for a MIO event.
        let sessions_to_tick = self
            .clients
            .iter()
            .filter_map(|(t, s)| {
                if wants_to_tick(&s.channel) {
                    Some((*t, Ready::EMPTY, None))
                } else {
                    None
                }
            })
            .chain(self.workers.iter().filter_map(|(token, session)| {
                if session.run_state != RunState::Stopped && wants_to_tick(&session.channel) {
                    Some((*token, Ready::EMPTY, None))
                } else {
                    None
                }
            }))
            .collect::<Vec<_>>();

        let workers_to_spawn = self.workers_to_spawn();
        if !sessions_to_tick.is_empty() || workers_to_spawn > 0 {
            // There is work to do right now: do not block in poll.
            poll_timeout = Some(Duration::default());
        }

        events.clear();
        trace!("Tasks: {:?}", self.tasks);
        trace!("Sessions to tick: {:?}", sessions_to_tick);
        trace!("Polling timeout: {:?}", poll_timeout);
        match self.poll.poll(&mut events, poll_timeout) {
            Ok(()) => {}
            Err(error) => error!("Error while polling: {:?}", error),
        }

        self.automatic_worker_spawn(workers_to_spawn);

        let events = sessions_to_tick.into_iter().chain(
            events
                .into_iter()
                .map(|event| (event.token(), Ready::from(event), Some(event))),
        );
        for (token, ready, event) in events {
            match token {
                // The command socket listener: accept every pending connection.
                Token(0) => {
                    if run_state == ServerState::Stopping {
                        continue;
                    }
                    if ready.is_readable() {
                        while let Ok((stream, _addr)) = self.unix_listener.accept() {
                            self.register_client(stream);
                        }
                    }
                }
                token => {
                    trace!("{:?} got event: {:?}", token, event);
                    if let Some((server, client)) = self.get_client_mut(&token) {
                        client.update_readiness(ready);
                        match client.ready() {
                            ClientResult::NothingToDo => {}
                            ClientResult::NewRequest(request) => {
                                debug!("Received new request: {:?}", request);
                                server.handle_client_request(client, request);
                                self.flush_pending_audit_events();
                            }
                            ClientResult::CloseSession => {
                                info!("Closing client {}", client.id);
                                debug!("closing client {:?}", client);
                                self.event_subscribers.remove(&token);
                                self.clients.remove(&token);
                            }
                        }
                    } else if let Some(worker) = self.workers.get_mut(&token) {
                        // Worker responses are not read while stopping.
                        if run_state == ServerState::Stopping {
                            continue;
                        }
                        worker.update_readiness(ready);
                        let worker_id = worker.id;
                        match worker.ready() {
                            WorkerResult::NothingToDo => {}
                            WorkerResult::NewResponses(responses) => {
                                for response in responses {
                                    self.handle_worker_response(worker_id, response);
                                }
                            }
                            WorkerResult::CloseSession => self.handle_worker_close(&token),
                        }
                    }
                }
            }
        }
    }
}
fn handle_worker_response(&mut self, worker_id: WorkerId, response: WorkerResponse) {
if let Some(ResponseContent {
content_type: Some(ContentType::Event(event)),
}) = response.content
{
if event.kind == EventKind::MetricDetailChanged as i32 {
if let Some(transition) = event.metric_detail.as_ref() {
crate::command::requests::audit_worker_metric_detail_transition(
&mut self.server,
worker_id,
transition,
);
}
}
for client_token in &self.server.event_subscribers {
if let Some(client) = self.clients.get_mut(client_token) {
client.return_processing_with_content(
format!("{worker_id}"),
ContentType::Event(event.clone()).into(),
);
}
}
return;
}
let Some(task_id) = self.in_flight.get(&response.id).copied() else {
warn!("Got a response for an unknown task: {}", response);
return;
};
let task = match self.tasks.get_mut(&task_id) {
Some(task) => task,
None => {
warn!("Got a response for an unknown task");
return;
}
};
let client = &mut task
.job
.client_token()
.and_then(|token| self.clients.get_mut(&token));
task.job
.get_gatherer()
.on_message(&mut self.server, client, worker_id, response);
}
/// Ends a task that was removed from the task map.
///
/// Calls the task's `on_finish` with the client it reports to (when that client
/// is still connected) and with `timed_out`, so the task can tell a deadline
/// expiry from a normal completion. Every in-flight request that belonged to the
/// task is then forgotten.
fn handle_finishing_task(&mut self, task_id: TaskId, task: TaskContainer, timed_out: bool) {
    if timed_out {
        debug!("Task timeout: {:?}", task);
    } else {
        debug!("Task finish: {:?}", task);
    }
    let client = &mut task
        .job
        .client_token()
        .and_then(|token| self.clients.get_mut(&token));
    // Forward the real reason: a hard-coded `false` would hide timeouts from the task.
    task.job.on_finish(&mut self.server, client, timed_out);
    self.in_flight
        .retain(|_, in_flight_task_id| *in_flight_task_id != task_id);
}
}
/// Errors returned by `Server` operations.
#[derive(thiserror::Error, Debug)]
pub enum ServerError {
/// `mio::Poll::new` failed.
#[error("Could not create Poll with MIO: {0:?}")]
CreatePoll(IoError),
/// Registering a listener or a stream in the MIO registry failed.
#[error("Could not register channel in MIO registry: {0:?}")]
RegisterChannel(IoError),
/// `fork_main_into_worker` failed.
#[error("Could not fork the main into a new worker: {0}")]
ForkMain(WorkerError),
/// A worker session was not found right after being inserted.
#[error("Did not find worker. This should NOT happen.")]
WorkerNotFound,
/// Setting close-on-exec on the command listener failed.
#[error("could not enable cloexec: {0}")]
EnableCloexec(UtilError),
/// Clearing close-on-exec on the command listener failed.
#[error("could not disable cloexec: {0}")]
DisableCloexec(UtilError),
}
/// Lifecycle state of the main process.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ServerState {
/// Normal operation; the only state in which workers are respawned.
Running,
/// NOTE(review): set outside this file — presumably workers have been asked
/// to stop while the main process keeps serving clients; confirm.
WorkersStopping,
/// The event loop refuses new connections and worker responses, and returns
/// once no client has pending data.
Stopping,
}
/// State of the main process that request handlers operate on.
pub struct Server {
/// Configuration the server was started with.
pub config: Config,
/// Tokens of the clients that receive events.
pub event_subscribers: HashSet<Token>,
/// Path of the executable, passed along when forking a worker.
pub executable_path: String,
// Maps the id of each request sent to a worker to the task awaiting its response.
in_flight: HashMap<RequestId, TaskId>,
// Counters backing the `next_*` allocators below.
next_client_id: ClientId,
next_session_id: SessionId,
next_task_id: TaskId,
next_worker_id: WorkerId,
/// Audit events queued by `push_audit_event` until the hub flushes them to subscribers.
pub pending_audit_events: VecDeque<Event>,
/// File opened from `config.audit_logs_target`, if configured and openable.
pub audit_log_writer: Option<RefCell<File>>,
/// File opened from `config.audit_logs_json_target`, if configured and openable.
pub audit_log_json_writer: Option<RefCell<File>>,
/// Value carried over through `UpgradeData`; starts at 0.
pub boot_generation: u32,
// MIO poll in which the listener and every session are registered.
poll: Poll,
// Tasks created by `new_task`, adopted by the hub on the next loop turn.
queued_tasks: HashMap<TaskId, TaskContainer>,
/// Configuration state, handed to new workers and carried over on upgrade.
pub state: ConfigState,
/// Current lifecycle state.
pub run_state: ServerState,
/// Value returned by `CommandHub::run` when the loop ends.
pub upgrading: bool,
// Listener of the command socket, registered under `Token(0)`.
unix_listener: UnixListener,
/// Worker sessions, keyed by their MIO token.
pub workers: HashMap<Token, WorkerSession>,
}
impl Server {
/// Creates the server state around `unix_listener`.
///
/// The listener is registered in a new MIO poll under `Token(0)`. The audit sinks
/// named in `config` are opened on a best-effort basis: a sink that cannot be
/// opened is logged and left disabled instead of failing the construction.
///
/// # Errors
///
/// `ServerError::CreatePoll` or `ServerError::RegisterChannel` when MIO refuses to
/// create the poll or to register the listener.
fn new(
    mut unix_listener: UnixListener,
    config: Config,
    executable_path: String,
) -> Result<Self, ServerError> {
    let poll = mio::Poll::new().map_err(ServerError::CreatePoll)?;
    let listener_interest = Interest::READABLE | Interest::WRITABLE;
    poll.registry()
        .register(&mut unix_listener, Token(0), listener_interest)
        .map_err(ServerError::RegisterChannel)?;

    let audit_log_writer = config.audit_logs_target.as_deref().and_then(|path| {
        open_audit_log_file(path)
            .map(RefCell::new)
            .map_err(|err| {
                error!(
                    "Could not open audit log file {:?}: {}. Audit lines will be routed only through the standard logger.",
                    path, err
                );
            })
            .ok()
    });
    let audit_log_json_writer = config.audit_logs_json_target.as_deref().and_then(|path| {
        open_audit_log_file(path)
            .map(RefCell::new)
            .map_err(|err| {
                error!(
                    "Could not open audit JSON log file {:?}: {}. JSON audit sink disabled.",
                    path, err
                );
            })
            .ok()
    });

    Ok(Self {
        config,
        event_subscribers: HashSet::new(),
        executable_path,
        in_flight: HashMap::new(),
        next_client_id: 0,
        // Session tokens start at 1: `Token(0)` belongs to the listener.
        next_session_id: 1,
        next_task_id: 0,
        next_worker_id: 0,
        pending_audit_events: VecDeque::new(),
        audit_log_writer,
        audit_log_json_writer,
        boot_generation: 0,
        poll,
        queued_tasks: HashMap::new(),
        state: ConfigState::new(),
        run_state: ServerState::Running,
        upgrading: false,
        unix_listener,
        workers: HashMap::new(),
    })
}
pub fn append_audit_line(&self, line: &str) {
let Some(writer) = self.audit_log_writer.as_ref() else {
return;
};
let mut writer = writer.borrow_mut();
if let Err(err) = writeln!(writer, "{line}") {
error!("Could not append to audit log file: {}", err);
}
}
pub fn append_audit_json(&self, json: &str) {
let Some(writer) = self.audit_log_json_writer.as_ref() else {
return;
};
let mut writer = writer.borrow_mut();
if let Err(err) = writeln!(writer, "{json}") {
error!("Could not append to audit JSON file: {}", err);
}
}
}
/// Opens the audit log at `path` for appending, creating it when needed.
///
/// Missing parent directories are created with mode `0o750` on a best-effort
/// basis. A new file is requested with mode `0o640`. A file that already existed
/// with another mode is reset to `0o640`; a warning is emitted when its previous
/// mode had bits outside `0o640`, and when the reset fails.
///
/// # Errors
///
/// The I/O error returned by opening the file.
fn open_audit_log_file(path: &str) -> Result<File, IoError> {
    const AUDIT_FILE_MODE: u32 = 0o640;

    let path = Path::new(path);
    if let Some(parent) = path.parent() {
        // Errors are ignored here: if the directory is missing, `open` fails below.
        let _ = DirBuilder::new().recursive(true).mode(0o750).create(parent);
    }

    // Permission bits of the file if it was there before this call.
    let previous_mode = path
        .exists()
        .then(|| std::fs::metadata(path))
        .and_then(Result::ok)
        .map(|meta| meta.permissions().mode() & 0o777);

    let mut options = OpenOptions::new();
    options.append(true).create(true).mode(AUDIT_FILE_MODE);
    let file = options.open(path)?;

    match previous_mode {
        Some(prev) if prev != AUDIT_FILE_MODE => {
            if prev & !AUDIT_FILE_MODE != 0 {
                warn!(
                    "audit log file {} pre-existed with mode 0o{:o} — narrowing to 0o640",
                    path.display(),
                    prev
                );
            }
            let reset = file.set_permissions(Permissions::from_mode(AUDIT_FILE_MODE));
            if let Err(e) = reset {
                warn!(
                    "could not narrow audit log file {} permissions to 0o640: {:?}",
                    path.display(),
                    e
                );
            }
        }
        _ => {}
    }
    Ok(file)
}
impl Server {
/// Forks a new worker, registers its session and asks it for its status.
///
/// `listeners` are handed to the worker; when absent, default listeners are used.
///
/// # Errors
///
/// `ServerError::ForkMain` when the fork fails, or the error of
/// [`Server::register_worker`].
pub fn launch_new_worker(
    &mut self,
    listeners: Option<Listeners>,
) -> Result<&mut WorkerSession, ServerError> {
    let worker_id = self.next_worker_id();
    let worker_name = worker_id.to_string();
    let forked = fork_main_into_worker(
        &worker_name,
        &self.config,
        self.executable_path.clone(),
        &self.state,
        Some(listeners.unwrap_or_default()),
    );
    let (pid, channel, scm_socket) = forked.map_err(ServerError::ForkMain)?;

    let session = self.register_worker(worker_id, pid, channel, scm_socket)?;
    let status_probe = WorkerRequest {
        id: format!("INITIAL-STATUS-{worker_id}"),
        content: RequestType::Status(Status {}).into(),
    };
    session.send(&status_probe);
    Ok(session)
}
/// Publishes the number of clusters, backends and frontends of the current
/// state as gauges.
pub fn update_counts(&mut self) {
    let clusters = self.state.clusters.len();
    let backends = self.state.count_backends();
    let frontends = self.state.count_frontends();

    gauge!(names::configuration::CLUSTERS, clusters);
    gauge!(names::configuration::BACKENDS, backends);
    gauge!(names::configuration::FRONTENDS, frontends);
}
/// Queues `event` at the back of `pending_audit_events`.
pub fn push_audit_event(&mut self, event: Event) {
self.pending_audit_events.push_back(event);
}
/// Returns an unused session token and advances the session counter.
fn next_session_token(&mut self) -> Token {
    let session_id = self.next_session_id;
    self.next_session_id = session_id + 1;
    Token(session_id)
}
/// Returns an unused client id and advances the client counter.
fn next_client_id(&mut self) -> ClientId {
    let current = self.next_client_id;
    self.next_client_id = current + 1;
    current
}
/// Returns an unused task id and advances the task counter.
fn next_task_id(&mut self) -> TaskId {
    let current = self.next_task_id;
    self.next_task_id = current + 1;
    current
}
/// Returns an unused worker id and advances the worker counter.
fn next_worker_id(&mut self) -> WorkerId {
    let current = self.next_worker_id;
    self.next_worker_id = current + 1;
    current
}
/// Registers `stream` in the poll under `token`, for both readable and writable
/// readiness.
///
/// # Errors
///
/// `ServerError::RegisterChannel` when the MIO registry refuses the stream.
fn register(&mut self, token: Token, stream: &mut UnixStream) -> Result<(), ServerError> {
    let interest = Interest::READABLE | Interest::WRITABLE;
    let registry = self.poll.registry();
    registry
        .register(stream, token, interest)
        .map_err(ServerError::RegisterChannel)
}
/// Looks up the session of the worker `id`, ignoring sessions that are not active.
pub fn get_active_worker_by_id(&self, id: WorkerId) -> Option<&WorkerSession> {
    self.workers
        .values()
        .filter(|session| session.is_active())
        .find(|session| session.id == id)
}
/// Creates the session of a worker and registers its channel in the poll.
///
/// # Errors
///
/// `ServerError::RegisterChannel` when the channel socket cannot be registered;
/// `ServerError::WorkerNotFound` if the session cannot be read back after insertion.
pub fn register_worker(
    &mut self,
    worker_id: WorkerId,
    pid: pid_t,
    mut channel: Channel<WorkerRequest, WorkerResponse>,
    scm_socket: ScmSocket,
) -> Result<&mut WorkerSession, ServerError> {
    let token = self.next_session_token();
    self.register(token, &mut channel.sock)?;

    let session = WorkerSession::new(channel, worker_id, pid, token, scm_socket);
    self.workers.insert(token, session);

    match self.workers.get_mut(&token) {
        Some(registered) => Ok(registered),
        None => Err(ServerError::WorkerNotFound),
    }
}
/// Queues `job` as a new task and returns its id.
///
/// The deadline is computed from now: none for `Timeout::None`,
/// `config.worker_timeout` seconds for `Timeout::Default`, the given duration for
/// `Timeout::Custom`.
pub fn new_task(&mut self, job: Box<dyn GatheringTask>, timeout: Timeout) -> TaskId {
    let task_id = self.next_task_id();

    let delay = match timeout {
        Timeout::None => None,
        Timeout::Default => Some(Duration::from_secs(self.config.worker_timeout as u64)),
        Timeout::Custom(duration) => Some(duration),
    };
    let timeout = delay.map(|delay| Instant::now() + delay);

    let container = TaskContainer { job, timeout };
    self.queued_tasks.insert(task_id, container);
    task_id
}
/// Queues `job` as a new task, then sends `request` to the workers on its
/// behalf (see `scatter_on`), using 0 as request index.
///
/// `target` restricts the request to the worker with that id; `None` sends it
/// to every worker that is not stopped.
pub fn scatter(
&mut self,
request: Request,
job: Box<dyn GatheringTask>,
timeout: Timeout,
target: Option<WorkerId>, ) {
let task_id = self.new_task(job, timeout);
self.scatter_on(request, task_id, 0, target);
}
/// Sends `request` to the workers on behalf of the queued task `task_id`.
///
/// Stopped workers are skipped; when `target` is set, only the worker with that
/// id is addressed. Each request gets the id
/// `<short name>-<worker id>-<task id>-<request id>`, which is recorded in
/// `in_flight`. The gatherer of the task is told how many responses to expect.
///
/// Logs an error and does nothing when `task_id` is not a queued task.
pub fn scatter_on(
    &mut self,
    request: Request,
    task_id: TaskId,
    request_id: usize,
    target: Option<WorkerId>,
) {
    let Some(task) = self.queued_tasks.get_mut(&task_id) else {
        error!("no task found with id {}", task_id);
        return;
    };

    // One request value is reused for all workers; only its id changes.
    let mut worker_request = WorkerRequest {
        id: String::new(),
        content: request,
    };
    let mut recipients = 0;

    for worker in self.workers.values_mut() {
        let is_targeted = target.is_none_or(|id| id == worker.id);
        if worker.run_state == RunState::Stopped || !is_targeted {
            continue;
        }
        recipients += 1;

        worker_request.id = format!(
            "{}-{}-{}-{}",
            worker_request.content.short_name(),
            worker.id,
            task_id,
            request_id,
        );
        debug!("scattering to worker {}: {:?}", worker.id, worker_request);
        worker.send(&worker_request);

        // The id is moved into the map; it is rebuilt for the next worker.
        let sent_id = std::mem::take(&mut worker_request.id);
        self.in_flight.insert(sent_id, task_id);
    }

    task.job.get_gatherer().inc_expected_responses(recipients);
}
/// Removes the task `task_id` from the queued tasks, if it is there.
///
/// Only `queued_tasks` is touched: `in_flight` entries are left as they are.
pub fn cancel_task(&mut self, task_id: TaskId) {
self.queued_tasks.remove(&task_id);
}
/// Logs the end of the worker session registered under `token`, then closes the
/// worker.
///
/// Logs an error and does nothing when no worker has this token.
pub fn handle_worker_close(&mut self, token: &Token) {
    let Some(worker) = self.workers.get(token) else {
        error!("No worker exists with token {:?}", token);
        return;
    };
    info!("closing session of worker {}", worker.id);
    trace!("closing worker session {:?}", worker);

    self.close_worker(token);
}
/// Number of workers missing to reach `config.worker_count`.
///
/// Always 0 when automatic restart is disabled or the server is not `Running`.
pub fn workers_to_spawn(&self) -> u16 {
    let may_respawn =
        self.config.worker_automatic_restart && self.run_state == ServerState::Running;
    if !may_respawn {
        return 0;
    }
    let alive = self.alive_workers() as u16;
    self.config.worker_count.saturating_sub(alive)
}
pub fn automatic_worker_spawn(&mut self, count: u16) {
if count == 0 {
return;
}
info!("Automatically restarting {} workers", count);
for _ in 0..count {
if let Err(err) = self.launch_new_worker(None) {
error!("could not launch new worker: {}", err);
}
}
}
/// Counts the worker sessions that report themselves as active.
fn alive_workers(&self) -> usize {
    let mut alive = 0;
    for worker in self.workers.values() {
        if worker.is_active() {
            alive += 1;
        }
    }
    alive
}
/// Sends `SIGKILL` to the worker registered under `token` and marks its session
/// as `Stopped`, whether or not the signal could be delivered.
///
/// Logs an error and does nothing when no worker has this token.
pub fn close_worker(&mut self, token: &Token) {
    let Some(worker) = self.workers.get_mut(token) else {
        error!("No worker exists with token {:?}", token);
        return;
    };

    let killed = kill(Pid::from_raw(worker.pid), Signal::SIGKILL).is_ok();
    if killed {
        info!("Worker {} was successfully killed", worker.id);
    } else {
        info!("worker {} was already dead", worker.id);
    }
    worker.run_state = RunState::Stopped;
}
pub fn disable_cloexec_before_upgrade(&mut self) -> Result<i32, ServerError> {
trace!(
"disabling cloexec on listener with file descriptor: {}",
self.unix_listener.as_raw_fd()
);
disable_close_on_exec(self.unix_listener.as_raw_fd()).map_err(ServerError::DisableCloexec)
}
/// Sets close-on-exec on the channel of every running worker and on the command
/// socket listener.
///
/// A failure on a worker channel is logged and does not stop the procedure.
///
/// # Errors
///
/// `ServerError::EnableCloexec` when the flag cannot be set on the listener.
pub fn enable_cloexec_after_upgrade(&mut self) -> Result<i32, ServerError> {
    let running_workers = self
        .workers
        .values_mut()
        .filter(|worker| worker.run_state == RunState::Running);
    for worker in running_workers {
        if let Err(e) = enable_close_on_exec(worker.channel.fd()) {
            error!(
                "could not enable close on exec for worker {}: {}",
                worker.id, e
            );
        }
    }

    let listener_fd = self.unix_listener.as_raw_fd();
    enable_close_on_exec(listener_fd).map_err(ServerError::EnableCloexec)
}
pub fn generate_upgrade_data(&self) -> UpgradeData {
UpgradeData {
command_socket_fd: self.unix_listener.as_raw_fd(),
config: self.config.clone(),
workers: self
.workers
.values()
.filter_map(|session| match SerializedWorkerSession::try_from(session) {
Ok(serialized_session) => Some(serialized_session),
Err(err) => {
error!("failed to serialize worker session: {}", err);
None
}
})
.collect(),
state: self.state.clone(),
next_client_id: self.next_client_id,
next_session_id: self.next_session_id,
next_task_id: self.next_task_id,
next_worker_id: self.next_worker_id,
boot_generation: self.boot_generation,
}
}
}
/// Credentials of the process on the other end of a client connection.
///
/// Every field is `None` when the credentials could not be read.
#[derive(Clone, Copy, Debug, Default)]
pub struct PeerCred {
/// User id of the peer.
pub uid: Option<u32>,
/// Group id of the peer.
pub gid: Option<u32>,
/// Process id of the peer.
pub pid: Option<i32>,
}
/// Reads the peer credentials of `stream` through the `PeerCredentials` socket
/// option.
///
/// When the option cannot be read, a warning is logged and empty credentials are
/// returned.
#[cfg(any(target_os = "linux", target_os = "android"))]
pub(crate) fn peer_cred_from_stream(stream: &UnixStream) -> PeerCred {
    use nix::sys::socket::{getsockopt, sockopt::PeerCredentials};

    let creds = match getsockopt(stream, PeerCredentials) {
        Ok(creds) => creds,
        Err(err) => {
            warn!("Could not read SO_PEERCRED on command socket: {}", err);
            return PeerCred::default();
        }
    };
    PeerCred {
        uid: Some(creds.uid()),
        gid: Some(creds.gid()),
        pid: Some(creds.pid()),
    }
}
/// Fallback for targets other than Linux and Android: no credentials are read,
/// every field is `None`.
#[cfg(not(any(target_os = "linux", target_os = "android")))]
pub(crate) fn peer_cred_from_stream(_stream: &UnixStream) -> PeerCred {
PeerCred::default()
}
/// Best-effort lookup of the short command name of process `pid`.
///
/// Reads `/proc/<pid>/comm` and strips the trailing line terminator. Returns
/// `None` when the file cannot be read or when the name is empty.
#[cfg(any(target_os = "linux", target_os = "android"))]
pub(crate) fn peer_comm(pid: i32) -> Option<String> {
    // `comm` is the only file needed: one read per lookup, nothing else in /proc.
    let raw = std::fs::read_to_string(format!("/proc/{pid}/comm")).ok()?;
    let name = raw.trim_end_matches(['\n', '\r']);
    (!name.is_empty()).then(|| name.to_owned())
}
/// Fallback for targets other than Linux and Android: the command name is never
/// resolved.
#[cfg(not(any(target_os = "linux", target_os = "android")))]
pub(crate) fn peer_comm(_pid: i32) -> Option<String> {
None
}
#[cfg(any(target_os = "linux", target_os = "android"))]
pub(crate) fn peer_user(uid: u32) -> Option<String> {
use std::sync::Mutex;
use nix::unistd::{Uid, User};
const MAX_PEER_USER_CACHE: usize = 16;
static CACHE: Mutex<Vec<(u32, Option<String>)>> = Mutex::new(Vec::new());
if let Ok(guard) = CACHE.lock()
&& let Some((_, cached)) = guard.iter().find(|(k, _)| *k == uid)
{
return cached.clone();
}
let resolved = match User::from_uid(Uid::from_raw(uid)) {
Ok(Some(user)) => Some(user.name),
Ok(None) => None,
Err(err) => {
warn!("Could not resolve username for uid {}: {}", uid, err);
None
}
};
if let Ok(mut guard) = CACHE.lock() {
if guard.len() >= MAX_PEER_USER_CACHE {
guard.remove(0);
}
guard.push((uid, resolved.clone()));
}
resolved
}
/// Fallback for targets other than Linux and Android: the user name is never
/// resolved.
#[cfg(not(any(target_os = "linux", target_os = "android")))]
pub(crate) fn peer_user(_uid: u32) -> Option<String> {
None
}
/// Hand-written `Debug` output: the audit writers, the boot generation, the
/// configuration state and the `upgrading` flag are left out, and the pending
/// audit events are shown as a count only.
impl Debug for Server {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            config,
            event_subscribers,
            executable_path,
            in_flight,
            next_client_id,
            next_session_id,
            next_task_id,
            next_worker_id,
            pending_audit_events,
            poll,
            queued_tasks,
            run_state,
            unix_listener,
            workers,
            ..
        } = self;

        let mut out = f.debug_struct("Server");
        out.field("config", config);
        out.field("event_subscribers", event_subscribers);
        out.field("executable_path", executable_path);
        out.field("in_flight", in_flight);
        out.field("next_client_id", next_client_id);
        out.field("next_session_id", next_session_id);
        out.field("next_task_id", next_task_id);
        out.field("next_worker_id", next_worker_id);
        out.field("pending_audit_events", &pending_audit_events.len());
        out.field("poll", poll);
        out.field("queued_tasks", queued_tasks);
        out.field("run_state", run_state);
        out.field("unix_listener", unix_listener);
        out.field("workers", workers);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use sozu_command_lib::{
config::Config,
proto::command::{
AddBackend, Cluster, RequestHttpFrontend, RequestTcpFrontend, SocketAddress,
request::RequestType,
},
};
use sozu_lib::metrics::METRICS;
use sozu_command_lib::proto::command::{PathRule, RulePosition, filtered_metrics};
// Reads the current value of the gauge `key` from the local proxy metrics;
// `None` when the key is absent or is not a gauge.
fn read_gauge(key: &str) -> Option<u64> {
METRICS.with(|metrics| {
let mut m = metrics.borrow_mut();
let proxy_metrics = m.dump_local_proxy_metrics();
proxy_metrics.get(key).and_then(|fm| match &fm.inner {
Some(filtered_metrics::Inner::Gauge(v)) => Some(*v),
_ => None,
})
})
}
// Builds a server with the default configuration, listening on a socket in
// a temporary directory.
fn create_test_server() -> Server {
let dir = tempfile::tempdir().expect("Could not create temp dir");
let socket_path = dir.path().join("test.sock");
let unix_listener = UnixListener::bind(&socket_path).expect("Could not bind socket");
Server::new(unix_listener, Config::default(), "sozu".to_owned())
.expect("Could not create server")
}
// `update_counts` must publish the number of clusters, backends and
// frontends of the state, and only when it is called.
#[test]
fn update_counts_reflects_state() {
let mut server = create_test_server();
// An empty state yields three gauges at 0.
server.update_counts();
assert_eq!(read_gauge(names::configuration::CLUSTERS), Some(0));
assert_eq!(read_gauge(names::configuration::BACKENDS), Some(0));
assert_eq!(read_gauge(names::configuration::FRONTENDS), Some(0));
// Populate the state: 1 cluster, 3 backends, 1 HTTP + 1 TCP frontend.
server
.state
.dispatch(
&RequestType::AddCluster(Cluster {
cluster_id: String::from("cluster_1"),
..Default::default()
})
.into(),
)
.expect("Could not add cluster");
for i in 0..3 {
server
.state
.dispatch(
&RequestType::AddBackend(AddBackend {
cluster_id: String::from("cluster_1"),
backend_id: format!("cluster_1-{i}"),
address: SocketAddress::new_v4(127, 0, 0, 1, 1026 + i as u16),
..Default::default()
})
.into(),
)
.expect("Could not add backend");
}
server
.state
.dispatch(
&RequestType::AddHttpFrontend(RequestHttpFrontend {
cluster_id: Some(String::from("cluster_1")),
hostname: String::from("example.com"),
path: PathRule::prefix(String::from("/")),
address: SocketAddress::new_v4(0, 0, 0, 0, 8080),
position: RulePosition::Tree.into(),
..Default::default()
})
.into(),
)
.expect("Could not add frontend");
server
.state
.dispatch(
&RequestType::AddTcpFrontend(RequestTcpFrontend {
cluster_id: String::from("cluster_1"),
address: SocketAddress::new_v4(0, 0, 0, 0, 5432),
..Default::default()
})
.into(),
)
.expect("Could not add TCP frontend");
// Changing the state alone does not touch the gauges.
assert_eq!(read_gauge(names::configuration::CLUSTERS), Some(0));
server.update_counts();
assert_eq!(read_gauge(names::configuration::CLUSTERS), Some(1));
assert_eq!(read_gauge(names::configuration::BACKENDS), Some(3));
assert_eq!(read_gauge(names::configuration::FRONTENDS), Some(2));
}
}