use std::{pin::Pin, sync::Arc};
use atuin_client::{
database::Database,
history::{History, HistoryId, store::HistoryStore},
settings::Settings,
};
use dashmap::DashMap;
use eyre::Result;
use time::OffsetDateTime;
use tokio_stream::Stream;
use tonic::{Request, Response, Status};
use tracing::{Level, instrument};
use crate::{
daemon::{Component, DaemonHandle},
events::DaemonEvent,
history::{
EndHistoryReply, EndHistoryRequest, HistoryEntry, HistoryEventKind, ShutdownReply,
ShutdownRequest, StartHistoryReply, StartHistoryRequest, StatusReply, StatusRequest,
TailHistoryReply, TailHistoryRequest,
history_server::{History as HistorySvc, HistoryServer},
},
};
/// Wire-protocol version reported in the `protocol` field of the start, end and status replies.
const DAEMON_PROTOCOL_VERSION: u32 = 1;
/// Daemon component that tracks shell commands between their start and end.
///
/// The state lives behind an `Arc`, so the gRPC service returned by
/// `HistoryComponent::grpc_service` shares it with this component.
pub struct HistoryComponent {
    inner: Arc<HistoryComponentInner>,
}
/// Shared state used by both `HistoryComponent` and `HistoryGrpcService`.
struct HistoryComponentInner {
    /// Commands that have been started but not yet ended, keyed by their history id.
    running: DashMap<HistoryId, History>,
    /// Daemon handle; `None` until `Component::start` has run.
    handle: tokio::sync::RwLock<Option<DaemonHandle>>,
    /// Record store that ended commands are pushed to; `None` until `Component::start` has run.
    history_store: tokio::sync::RwLock<Option<HistoryStore>>,
}
impl HistoryComponent {
    /// Creates a component with an empty set of running commands.
    ///
    /// The daemon handle and history store remain unset until the component is started.
    pub fn new() -> Self {
        let shared = HistoryComponentInner {
            running: DashMap::new(),
            handle: tokio::sync::RwLock::new(None),
            history_store: tokio::sync::RwLock::new(None),
        };
        Self {
            inner: Arc::new(shared),
        }
    }

    /// Builds a tonic `HistoryServer` whose handlers share this component's state.
    pub fn grpc_service(&self) -> HistoryServer<HistoryGrpcService> {
        let service = HistoryGrpcService {
            inner: Arc::clone(&self.inner),
        };
        HistoryServer::new(service)
    }
}
impl Default for HistoryComponent {
fn default() -> Self {
Self::new()
}
}
#[tonic::async_trait]
impl Component for HistoryComponent {
    /// Name this component is registered under.
    fn name(&self) -> &'static str {
        "history"
    }

    /// Creates the `HistoryStore` for this host and publishes it, together with the daemon
    /// handle, into the state shared with the gRPC service.
    ///
    /// # Errors
    /// Propagates any error from `Settings::host_id`.
    async fn start(&mut self, handle: DaemonHandle) -> Result<()> {
        let shared = &self.inner;
        let host_id = Settings::host_id().await?;
        let store = HistoryStore::new(handle.store().clone(), host_id, *handle.encryption_key());

        {
            let mut store_slot = shared.history_store.write().await;
            *store_slot = Some(store);
        }
        {
            let mut handle_slot = shared.handle.write().await;
            *handle_slot = Some(handle);
        }

        tracing::info!("history component started");
        Ok(())
    }

    /// Ignores all daemon events; this component only emits them.
    async fn handle_event(&mut self, _event: &DaemonEvent) -> Result<()> {
        Ok(())
    }

    /// Logs the shutdown; there is nothing to release.
    async fn stop(&mut self) -> Result<()> {
        tracing::info!("history component stopped");
        Ok(())
    }
}
/// gRPC front-end for the history component.
///
/// Instances are created by `HistoryComponent::grpc_service` and share its state.
pub struct HistoryGrpcService {
    inner: Arc<HistoryComponentInner>,
}
/// Converts a `History` into the `TailHistoryReply` streamed to tail subscribers.
///
/// The reply is tagged with `kind`. A missing `intent` is sent as its type's default value.
fn history_to_tail_reply(kind: HistoryEventKind, history: History) -> TailHistoryReply {
    // NOTE(review): `as u64` reinterprets the i128 nanosecond count. Values outside u64's
    // range (e.g. pre-1970 timestamps) would not round-trip.
    let timestamp = history.timestamp.unix_timestamp_nanos() as u64;

    let entry = HistoryEntry {
        timestamp,
        id: history.id.0,
        command: history.command,
        cwd: history.cwd,
        session: history.session,
        hostname: history.hostname,
        author: history.author,
        intent: history.intent.unwrap_or_default(),
        exit: history.exit,
        duration: history.duration,
    };

    TailHistoryReply {
        kind: kind as i32,
        history: Some(entry),
    }
}
#[tonic::async_trait]
impl HistorySvc for HistoryGrpcService {
    /// Stream returned by `tail_history`: boxed replies, or a terminal error status.
    type TailHistoryStream = Pin<Box<dyn Stream<Item = Result<TailHistoryReply, Status>> + Send>>;

    /// Registers a command that has just started.
    ///
    /// Builds a daemon-side `History` from the request. If the component has been started,
    /// it emits `DaemonEvent::HistoryStarted`. The entry is then kept in the in-memory
    /// `running` map until `end_history` is called with the returned id.
    ///
    /// # Errors
    /// Returns `invalid_argument` if `timestamp` (nanoseconds since the Unix epoch) cannot be
    /// converted into an `OffsetDateTime`.
    #[instrument(skip_all, level = Level::INFO)]
    async fn start_history(
        &self,
        request: Request<StartHistoryRequest>,
    ) -> Result<Response<StartHistoryReply>, Status> {
        let req = request.into_inner();
        let timestamp =
            OffsetDateTime::from_unix_timestamp_nanos(req.timestamp as i128).map_err(|_| {
                Status::invalid_argument(
                    "failed to parse timestamp as unix time (expected nanos since epoch)",
                )
            })?;
        let h: History = History::daemon()
            .timestamp(timestamp)
            .command(req.command)
            .cwd(req.cwd)
            .session(req.session)
            .hostname(req.hostname)
            .author(req.author)
            .intent(req.intent)
            .build()
            .into();
        if let Some(handle) = self.inner.handle.read().await.as_ref() {
            handle.emit(DaemonEvent::HistoryStarted(h.clone()));
        }
        let id = h.id.clone();
        tracing::info!(id = id.to_string(), "start history");
        self.inner.running.insert(id.clone(), h);
        let reply = StartHistoryReply {
            id: id.to_string(),
            version: env!("CARGO_PKG_VERSION").to_string(),
            protocol: DAEMON_PROTOCOL_VERSION,
        };
        Ok(Response::new(reply))
    }

    /// Completes a command previously registered with `start_history`.
    ///
    /// Removes the entry from `running` and sets its exit code and duration. A request
    /// duration of `0` means "not measured"; the duration is then computed from the start
    /// timestamp. The entry is saved to the history database and pushed to the record
    /// store, and `DaemonEvent::HistoryEnded` is emitted.
    ///
    /// # Errors
    /// - `internal` if the component has not been started, or if the database write or the
    ///   record-store push fails.
    /// - `not_found` if no running command has the given id.
    /// - `invalid_argument` if the supplied or computed duration does not fit in an `i64`
    ///   nanosecond count. This used to panic the handler.
    #[instrument(skip_all, level = Level::INFO)]
    async fn end_history(
        &self,
        request: Request<EndHistoryRequest>,
    ) -> Result<Response<EndHistoryReply>, Status> {
        let req = request.into_inner();
        let id = HistoryId(req.id);

        // Check initialization before removing the entry, so an early call does not
        // silently discard an in-flight command.
        let handle_guard = self.inner.handle.read().await;
        let handle = handle_guard
            .as_ref()
            .ok_or_else(|| Status::internal("component not initialized"))?;
        let store_guard = self.inner.history_store.read().await;
        let history_store = store_guard
            .as_ref()
            .ok_or_else(|| Status::internal("component not initialized"))?;

        let Some((_, mut history)) = self.inner.running.remove(&id) else {
            return Err(Status::not_found(format!(
                "could not find history with id: {id}"
            )));
        };

        history.exit = req.exit;
        history.duration = match req.duration {
            // The client did not measure the duration; derive it from the start time.
            // A start timestamp far in the future can overflow i64, so reject it instead
            // of panicking.
            0 => i64::try_from(
                (OffsetDateTime::now_utc() - history.timestamp).whole_nanoseconds(),
            )
            .map_err(|_| {
                Status::invalid_argument(
                    "calculated duration does not fit in i64 nanoseconds (bad start timestamp?)",
                )
            })?,
            value => i64::try_from(value).map_err(|_| {
                Status::invalid_argument("duration does not fit in i64 nanoseconds")
            })?,
        };

        handle
            .history_db()
            .save(&history)
            .await
            .map_err(|e| Status::internal(format!("failed to write to db: {e:?}")))?;
        tracing::info!(
            id = id.0.to_string(),
            duration = history.duration,
            "end history"
        );
        let (record_id, idx) = history_store
            .push(history.clone())
            .await
            .map_err(|e| Status::internal(format!("failed to push record to store: {e:?}")))?;
        handle.emit(DaemonEvent::HistoryEnded(history));

        let reply = EndHistoryReply {
            id: record_id.0.to_string(),
            idx,
            version: env!("CARGO_PKG_VERSION").to_string(),
            protocol: DAEMON_PROTOCOL_VERSION,
        };
        Ok(Response::new(reply))
    }

    /// Streams history start/end events to the caller.
    ///
    /// Subscribes to the daemon's event broadcast. `HistoryStarted` and `HistoryEnded` are
    /// forwarded as `TailHistoryReply` messages; all other events are ignored. If the
    /// subscriber lags and events are dropped, a single `resource_exhausted` status is sent
    /// and the stream ends. The stream also ends when the broadcast closes or when the
    /// client side of the channel goes away (detected on the next forwarded event).
    ///
    /// # Errors
    /// Returns `internal` if the component has not been started.
    #[instrument(skip_all, level = Level::INFO)]
    async fn tail_history(
        &self,
        _request: Request<TailHistoryRequest>,
    ) -> Result<Response<Self::TailHistoryStream>, Status> {
        // Clone the handle out so the read guard is released immediately.
        let handle = self
            .inner
            .handle
            .read()
            .await
            .as_ref()
            .cloned()
            .ok_or_else(|| Status::internal("component not initialized"))?;
        let mut rx = handle.subscribe();
        let (tx, out_rx) = tokio::sync::mpsc::channel::<Result<TailHistoryReply, Status>>(128);
        tokio::spawn(async move {
            loop {
                let event = match rx.recv().await {
                    Ok(event) => event,
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                        let _ = tx
                            .send(Err(Status::resource_exhausted(format!(
                                "tail stream lagged behind and dropped {skipped} events"
                            ))))
                            .await;
                        break;
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                };
                let reply = match event {
                    DaemonEvent::HistoryStarted(history) => {
                        Some(history_to_tail_reply(HistoryEventKind::Started, history))
                    }
                    DaemonEvent::HistoryEnded(history) => {
                        Some(history_to_tail_reply(HistoryEventKind::Ended, history))
                    }
                    _ => None,
                };
                // A failed send means the client dropped the stream; stop forwarding.
                if let Some(reply) = reply
                    && tx.send(Ok(reply)).await.is_err()
                {
                    break;
                }
            }
        });
        let stream = tokio_stream::wrappers::ReceiverStream::new(out_rx);
        Ok(Response::new(Box::pin(stream)))
    }

    /// Reports liveness (always `healthy: true`), the crate version, the process id and the
    /// protocol version.
    #[instrument(skip_all, level = Level::INFO)]
    async fn status(
        &self,
        _request: Request<StatusRequest>,
    ) -> Result<Response<StatusReply>, Status> {
        let reply = StatusReply {
            healthy: true,
            version: env!("CARGO_PKG_VERSION").to_string(),
            pid: std::process::id(),
            protocol: DAEMON_PROTOCOL_VERSION,
        };
        Ok(Response::new(reply))
    }

    /// Asks the daemon to shut down if the component has been started.
    ///
    /// Always replies `accepted: true`.
    #[instrument(skip_all, level = Level::INFO)]
    async fn shutdown(
        &self,
        _request: Request<ShutdownRequest>,
    ) -> Result<Response<ShutdownReply>, Status> {
        if let Some(handle) = self.inner.handle.read().await.as_ref() {
            handle.shutdown();
        }
        Ok(Response::new(ShutdownReply { accepted: true }))
    }
}