use std::{net::SocketAddr, sync::Arc, time::Duration};
use ::jwt::VerifyWithKey;
use axum::{
extract::{
ws::{Message, WebSocket},
Query, State, WebSocketUpgrade,
},
http::HeaderMap,
middleware,
response::{IntoResponse, Response},
routing::get,
Extension, Router,
};
use futures_util::stream::StreamExt;
use http::StatusCode;
use serde::Deserialize;
use snops_common::{
constant::HEADER_AGENT_KEY,
prelude::*,
rpc::control::{
agent::{AgentServiceClient, Handshake},
ControlService,
},
};
use tarpc::server::Channel;
use tokio::select;
use tracing::{error, info, warn};
use self::{
error::StartError,
jwt::{Claims, JWT_SECRET},
rpc::ControlRpcServer,
};
use crate::{
logging::{log_request, req_stamp},
server::rpc::{MuxedMessageIncoming, MuxedMessageOutgoing},
state::{Agent, AgentFlags, AppState, GlobalState},
};
pub mod actions;
mod api;
mod content;
pub mod error;
pub mod jwt;
pub mod models;
pub mod prometheus;
mod rpc;
/// Builds the control plane HTTP application and serves it on `socket_addr`.
///
/// The application exposes the agent websocket endpoint at `/agent` and nests
/// the `api`, `prometheus` and `content` route groups under `/api/v1`,
/// `/prometheus` and `/content`. The shared state is provided both as router
/// state and as an [`Extension`], and every request passes through the
/// `req_stamp` and `log_request` middleware.
///
/// # Errors
///
/// Returns [`StartError::TcpBind`] if the listener cannot be bound to
/// `socket_addr`, and [`StartError::Serve`] if serving fails.
pub async fn start(state: Arc<GlobalState>, socket_addr: SocketAddr) -> Result<(), StartError> {
    // Route groups that need no asynchronous setup.
    let routes = Router::new()
        .route("/agent", get(agent_ws_handler))
        .nest("/api/v1", api::routes())
        .nest("/prometheus", prometheus::routes());

    // The content routes are initialised asynchronously from the global state.
    let content_routes = content::init_routes(&state).await;

    let router_state = Arc::clone(&state);
    let app = routes
        .nest("/content", content_routes)
        .with_state(router_state)
        .layer(Extension(state))
        .layer(middleware::map_response(log_request))
        .layer(middleware::from_fn(req_stamp));

    let listener = match tokio::net::TcpListener::bind(socket_addr).await {
        Ok(listener) => listener,
        Err(e) => return Err(StartError::TcpBind(e)),
    };

    axum::serve(listener, app).await.map_err(StartError::Serve)
}
/// Query-string parameters accepted by the `/agent` websocket route.
#[derive(Debug, Deserialize)]
struct AgentWsQuery {
/// Agent id requested by the connecting agent, if any. When absent, a newly
/// registered agent is given an id from `AgentId::rand`.
id: Option<AgentId>,
/// Agent flags; flattened, so their fields are read from the same query
/// string alongside `id`.
#[serde(flatten)]
flags: AgentFlags,
}
/// Handles websocket upgrade requests on the `/agent` route.
///
/// When the server state has an agent key configured, the request must carry
/// a matching `HEADER_AGENT_KEY` header; otherwise it is rejected with
/// `401 Unauthorized` before any upgrade happens. When no key is configured,
/// every request is upgraded and handed to `handle_socket`.
async fn agent_ws_handler(
    ws: WebSocketUpgrade,
    headers: HeaderMap,
    State(state): State<AppState>,
    Query(query): Query<AgentWsQuery>,
) -> Response {
    if let Some(expected_key) = &state.agent_key {
        // A missing header never matches; a header value that fails `to_str`
        // is compared as the empty string.
        let key_matches = match headers.get(HEADER_AGENT_KEY) {
            Some(provided) => expected_key == provided.to_str().unwrap_or_default(),
            None => false,
        };

        if !key_matches {
            warn!("an agent has attempted to connect with a mismatching agent key");
            return StatusCode::UNAUTHORIZED.into_response();
        }
    }

    let upgrade = ws.on_upgrade(move |socket| handle_socket(socket, headers, state, query));
    upgrade.into_response()
}
async fn handle_socket(
mut socket: WebSocket,
headers: HeaderMap,
state: AppState,
query: AgentWsQuery,
) {
let claims = headers
.get("Authorization")
.and_then(|auth| -> Option<Claims> {
let auth = auth.to_str().ok()?;
if !auth.starts_with("Bearer ") {
return None;
}
let token = &auth[7..];
token.verify_with_key(&*JWT_SECRET).ok()
})
.filter(|claims| {
if let Some(id) = query.id {
if claims.id != id {
warn!("connecting agent specified an id different than the claim");
return false;
}
}
true
});
let (client_response_in, client_transport, mut client_request_out) = RpcTransport::new();
let (server_request_in, server_transport, mut server_response_out) = RpcTransport::new();
let client =
AgentServiceClient::new(tarpc::client::Config::default(), client_transport).spawn();
let id: AgentId = 'insertion: {
let client = client.clone();
let mut handshake = Handshake {
loki: state.cli.loki.as_ref().map(|u| u.to_string()),
..Default::default()
};
'reconnect: {
if let Some(claims) = claims {
let Some(mut agent) = state.pool.get_mut(&claims.id) else {
warn!("connecting agent is trying to identify as an unrecognized agent");
break 'reconnect;
};
let id = agent.id();
if agent.is_connected() {
warn!(
"connecting agent is trying to identify as an already-connected agent {id}"
);
break 'reconnect;
}
if agent.claims().nonce != claims.nonce {
warn!("connecting agent {id} is trying to identify with an invalid nonce");
break 'reconnect;
}
if let AgentState::Node(env, _) = agent.state() {
if !state.envs.contains_key(env) {
info!("setting agent {id} to Inventory state due to missing env {env}");
agent.set_state(AgentState::Inventory);
}
}
agent.state().clone_into(&mut handshake.state);
agent.mark_connected(client, query.flags);
info!("agent {id} reconnected");
if let Err(e) = state.db.agents.save(&id, &agent) {
error!("failed to save agent {id} to the database: {e}");
}
let client = agent.rpc().cloned().unwrap();
drop(agent);
tokio::spawn(async move {
let mut ctx = tarpc::context::current();
ctx.deadline += Duration::from_secs(300);
match client.handshake(ctx, handshake).await {
Ok(Ok(())) => (),
Ok(Err(e)) => {
error!("failed to perform agent {id} handshake reconciliation: {e}")
}
Err(e) => error!("failed to perform agent {id} handshake: {e}"),
}
});
break 'insertion id;
}
}
let id = query.id.unwrap_or_else(AgentId::rand);
if state
.pool
.get(&id)
.map(|a| a.is_connected())
.unwrap_or_default()
{
warn!("an agent is trying to identify as an already-connected agent {id}");
let _ = socket.send(Message::Close(None)).await;
return;
}
let agent = Agent::new(client.to_owned(), id, query.flags);
let signed_jwt = agent.sign_jwt();
handshake.jwt = Some(signed_jwt);
tokio::spawn(async move {
let mut ctx = tarpc::context::current();
ctx.deadline += Duration::from_secs(300);
match client.handshake(ctx, handshake).await {
Ok(Ok(())) => (),
Ok(Err(e)) => error!("failed to perform agent {id} handshake reconciliation: {e}"),
Err(e) => error!("failed to perform agent {id} handshake: {e}"),
}
});
if let Err(e) = state.db.agents.save(&id, &agent) {
error!("failed to save agent {id} to the database: {e}");
}
state.pool.insert(id, agent);
info!(
"agent {id} connected; pool is now {} nodes",
state.pool.len()
);
id
};
let state2 = Arc::clone(&state);
tokio::spawn(async move {
if let Ok((ports, external, internal)) = client.get_addrs(tarpc::context::current()).await {
if let Some(mut agent) = state2.pool.get_mut(&id) {
info!(
"agent {id} [{}], labels: {:?}, addrs: {external:?} {internal:?} @ {ports}, local pk: {}",
agent.modes(),
agent.str_labels(),
if agent.has_local_pk() { "yes" } else { "no" },
);
agent.set_ports(ports);
agent.set_addrs(external, internal);
if let Err(e) = state2.db.agents.save(&id, &agent) {
error!("failed to save agent {id} to the database: {e}");
}
}
}
});
let server = tarpc::server::BaseChannel::with_defaults(server_transport);
let server_handle = tokio::spawn(
server
.execute(
ControlRpcServer {
state: state.to_owned(),
agent: id,
}
.serve(),
)
.for_each(|r| async move {
tokio::spawn(r);
}),
);
loop {
select! {
msg = socket.recv() => {
match msg {
Some(Err(_)) | None => break,
Some(Ok(Message::Binary(bin))) => {
let msg = match bincode::deserialize(&bin) {
Ok(msg) => msg,
Err(e) => {
error!("failed to deserialize a message from agent {id}: {e}");
continue;
}
};
match msg {
MuxedMessageIncoming::Parent(msg) => server_request_in.send(msg).expect("internal RPC channel closed"),
MuxedMessageIncoming::Child(msg) => client_response_in.send(msg).expect("internal RPC channel closed"),
}
}
_ => (),
}
}
msg = client_request_out.recv() => {
let msg = msg.expect("internal RPC channel closed");
let bin = bincode::serialize(&MuxedMessageOutgoing::Child(msg)).expect("failed to serialize request");
if socket.send(Message::Binary(bin)).await.is_err() {
break;
}
}
msg = server_response_out.recv() => {
let msg = msg.expect("internal RPC channel closed");
let bin = bincode::serialize(&MuxedMessageOutgoing::Parent(msg)).expect("failed to serialize response");
if socket.send(Message::Binary(bin)).await.is_err() {
break;
}
}
}
}
server_handle.abort();
{
if let Some(mut agent) = state.pool.get_mut(&id) {
agent.mark_disconnected();
}
info!("agent {id} disconnected");
}
}