1use atuin_client::database::Context;
2use atuin_client::settings::{FilterMode, Settings};
3use eyre::{Context as EyreContext, Result};
4#[cfg(windows)]
5use tokio::net::TcpStream;
6use tonic::Code;
7use tonic::transport::{Channel, Endpoint, Uri};
8use tower::service_fn;
9
10use hyper_util::rt::TokioIo;
11
12#[cfg(unix)]
13use tokio::net::UnixStream;
14
15use atuin_client::history::History;
16use tracing::{Level, instrument, span};
17
18use crate::control::HistoryRebuiltEvent;
19use crate::control::{
20 ForceSyncEvent, HistoryDeletedEvent, HistoryPrunedEvent, SendEventRequest,
21 SettingsReloadedEvent, ShutdownEvent, control_client::ControlClient as ControlServiceClient,
22};
23use crate::events::DaemonEvent;
24use crate::history::{
25 EndHistoryReply, EndHistoryRequest, ShutdownRequest, StartHistoryReply, StartHistoryRequest,
26 StatusReply, StatusRequest, TailHistoryReply, TailHistoryRequest,
27 history_client::HistoryClient as HistoryServiceClient,
28};
29use crate::search::{
30 FilterMode as RpcFilterMode, SearchContext as RpcSearchContext, SearchRequest, SearchResponse,
31 search_client::SearchClient as SearchServiceClient,
32};
33
34pub struct HistoryClient {
35 client: HistoryServiceClient<Channel>,
36}
37
/// Coarse category of a failure produced by one of the daemon clients.
///
/// Computed by `classify_error` so callers can branch on the kind of failure
/// instead of inspecting the error chain themselves.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DaemonClientErrorKind {
    /// A `tonic::transport::Error` appears in the error chain.
    Connect,
    /// An RPC failed with status code `Unavailable`.
    Unavailable,
    /// An RPC failed with status code `Unimplemented`.
    Unimplemented,
    /// Any other failure.
    Other,
}
45
46#[must_use]
47pub fn classify_error(error: &eyre::Report) -> DaemonClientErrorKind {
48 for cause in error.chain() {
49 if cause.downcast_ref::<tonic::transport::Error>().is_some() {
50 return DaemonClientErrorKind::Connect;
51 }
52
53 if let Some(status) = cause.downcast_ref::<tonic::Status>() {
54 return match status.code() {
55 Code::Unavailable => DaemonClientErrorKind::Unavailable,
56 Code::Unimplemented => DaemonClientErrorKind::Unimplemented,
57 _ => DaemonClientErrorKind::Other,
58 };
59 }
60 }
61
62 DaemonClientErrorKind::Other
63}
64
65impl HistoryClient {
67 #[cfg(unix)]
68 pub async fn new(path: String) -> Result<Self> {
69 use eyre::Context;
70
71 let log_path = path.clone();
72 let channel = Endpoint::try_from("http://atuin_local_daemon:0")?
73 .connect_with_connector(service_fn(move |_: Uri| {
74 let path = path.clone();
75
76 async move {
77 Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?))
78 }
79 }))
80 .await
81 .wrap_err_with(|| {
82 format!(
83 "failed to connect to local atuin daemon at {}. Is it running?",
84 &log_path
85 )
86 })?;
87
88 let client = HistoryServiceClient::new(channel);
89
90 Ok(HistoryClient { client })
91 }
92
93 #[cfg(not(unix))]
94 pub async fn new(port: u64) -> Result<Self> {
95 let channel = Endpoint::try_from("http://atuin_local_daemon:0")?
96 .connect_with_connector(service_fn(move |_: Uri| {
97 let url = format!("127.0.0.1:{port}");
98
99 async move {
100 Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?))
101 }
102 }))
103 .await
104 .wrap_err_with(|| {
105 format!(
106 "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?"
107 )
108 })?;
109
110 let client = HistoryServiceClient::new(channel);
111
112 Ok(HistoryClient { client })
113 }
114
115 pub async fn start_history(&mut self, h: History) -> Result<StartHistoryReply> {
116 let req = StartHistoryRequest {
117 command: h.command,
118 cwd: h.cwd,
119 hostname: h.hostname,
120 session: h.session,
121 timestamp: h.timestamp.unix_timestamp_nanos() as u64,
122 author: h.author,
123 intent: h.intent.unwrap_or_default(),
124 };
125
126 Ok(self.client.start_history(req).await?.into_inner())
127 }
128
129 pub async fn end_history(
130 &mut self,
131 id: String,
132 duration: u64,
133 exit: i64,
134 ) -> Result<EndHistoryReply> {
135 let req = EndHistoryRequest { id, duration, exit };
136
137 Ok(self.client.end_history(req).await?.into_inner())
138 }
139
140 pub async fn status(&mut self) -> Result<StatusReply> {
141 Ok(self.client.status(StatusRequest {}).await?.into_inner())
142 }
143
144 pub async fn tail_history(&mut self) -> Result<tonic::Streaming<TailHistoryReply>> {
145 Ok(self
146 .client
147 .tail_history(TailHistoryRequest {})
148 .await?
149 .into_inner())
150 }
151
152 pub async fn shutdown(&mut self) -> Result<bool> {
153 let resp = self.client.shutdown(ShutdownRequest {}).await?.into_inner();
154 Ok(resp.accepted)
155 }
156}
157
158pub struct SearchClient {
159 client: SearchServiceClient<Channel>,
160}
161
162impl SearchClient {
163 #[cfg(unix)]
164 pub async fn new(path: String) -> Result<Self> {
165 let log_path = path.clone();
166 let channel = Endpoint::try_from("http://atuin_local_daemon:0")?
167 .connect_with_connector(service_fn(move |_: Uri| {
168 let path = path.clone();
169
170 async move {
171 Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?))
172 }
173 }))
174 .await
175 .wrap_err_with(|| {
176 format!(
177 "failed to connect to local atuin daemon at {}. Is it running?",
178 &log_path
179 )
180 })?;
181
182 let client = SearchServiceClient::new(channel);
183
184 Ok(SearchClient { client })
185 }
186
187 #[cfg(not(unix))]
188 pub async fn new(port: u64) -> Result<Self> {
189 let channel = Endpoint::try_from("http://atuin_local_daemon:0")?
190 .connect_with_connector(service_fn(move |_: Uri| {
191 let url = format!("127.0.0.1:{port}");
192
193 async move {
194 Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?))
195 }
196 }))
197 .await
198 .wrap_err_with(|| {
199 format!(
200 "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?"
201 )
202 })?;
203
204 let client = SearchServiceClient::new(channel);
205
206 Ok(SearchClient { client })
207 }
208
209 #[instrument(skip_all, level = Level::TRACE, name = "daemon_client_search", fields(query = %query, query_id = query_id))]
210 pub async fn search(
211 &mut self,
212 query: String,
213 query_id: u64,
214 filter_mode: FilterMode,
215 context: Option<Context>,
216 ) -> Result<tonic::Streaming<SearchResponse>> {
217 let request = SearchRequest {
218 query,
219 query_id,
220 filter_mode: RpcFilterMode::from(filter_mode).into(),
221 context: context.map(RpcSearchContext::from),
222 };
223 let request_stream = tokio_stream::once(request);
224 let response = span!(Level::TRACE, "daemon_client_search.request")
225 .in_scope(async || self.client.search(request_stream).await)
226 .await?;
227
228 Ok(response.into_inner())
229 }
230}
231
232impl From<FilterMode> for RpcFilterMode {
233 fn from(filter_mode: FilterMode) -> Self {
234 match filter_mode {
235 FilterMode::Global => RpcFilterMode::Global,
236 FilterMode::Host => RpcFilterMode::Host,
237 FilterMode::Session => RpcFilterMode::Session,
238 FilterMode::Directory => RpcFilterMode::Directory,
239 FilterMode::Workspace => RpcFilterMode::Workspace,
240 FilterMode::SessionPreload => RpcFilterMode::SessionPreload,
241 }
242 }
243}
244
245impl From<Context> for RpcSearchContext {
246 fn from(context: Context) -> Self {
247 RpcSearchContext {
248 session_id: context.session,
249 cwd: context.cwd,
250 hostname: context.hostname,
251 host_id: context.host_id,
252 git_root: context
253 .git_root
254 .map(|path| path.to_string_lossy().to_string()),
255 }
256 }
257}
258
259pub struct ControlClient {
267 client: ControlServiceClient<Channel>,
268}
269
270impl ControlClient {
271 #[cfg(unix)]
273 pub async fn new(path: String) -> Result<Self> {
274 let log_path = path.clone();
275 let channel = Endpoint::try_from("http://atuin_local_daemon:0")?
276 .connect_with_connector(service_fn(move |_: Uri| {
277 let path = path.clone();
278
279 async move {
280 Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?))
281 }
282 }))
283 .await
284 .wrap_err_with(|| {
285 format!(
286 "failed to connect to local atuin daemon at {}. Is it running?",
287 &log_path
288 )
289 })?;
290
291 let client = ControlServiceClient::new(channel);
292
293 Ok(ControlClient { client })
294 }
295
296 #[cfg(not(unix))]
298 pub async fn new(port: u64) -> Result<Self> {
299 let channel = Endpoint::try_from("http://atuin_local_daemon:0")?
300 .connect_with_connector(service_fn(move |_: Uri| {
301 let url = format!("127.0.0.1:{port}");
302
303 async move {
304 Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?))
305 }
306 }))
307 .await
308 .wrap_err_with(|| {
309 format!(
310 "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?"
311 )
312 })?;
313
314 let client = ControlServiceClient::new(channel);
315
316 Ok(ControlClient { client })
317 }
318
319 #[cfg(unix)]
321 pub async fn from_settings(settings: &Settings) -> Result<Self> {
322 Self::new(settings.daemon.socket_path.clone()).await
323 }
324
325 #[cfg(not(unix))]
327 pub async fn from_settings(settings: &Settings) -> Result<Self> {
328 Self::new(settings.daemon.tcp_port).await
329 }
330
331 pub async fn send_event(&mut self, event: DaemonEvent) -> Result<()> {
333 let proto_event = daemon_event_to_proto(event);
334 let request = SendEventRequest {
335 event: Some(proto_event),
336 };
337 self.client.send_event(request).await?;
338 Ok(())
339 }
340}
341
342fn daemon_event_to_proto(event: DaemonEvent) -> crate::control::send_event_request::Event {
344 use crate::control::send_event_request::Event;
345
346 match event {
347 DaemonEvent::HistoryPruned => Event::HistoryPruned(HistoryPrunedEvent {}),
348 DaemonEvent::HistoryRebuilt => Event::HistoryRebuilt(HistoryRebuiltEvent {}),
349 DaemonEvent::HistoryDeleted { ids } => Event::HistoryDeleted(HistoryDeletedEvent {
350 ids: ids.into_iter().map(|id| id.0).collect(),
351 }),
352 DaemonEvent::ForceSync => Event::ForceSync(ForceSyncEvent {}),
353 DaemonEvent::SettingsReloaded => Event::SettingsReloaded(SettingsReloadedEvent {}),
354 DaemonEvent::ShutdownRequested => Event::Shutdown(ShutdownEvent {}),
355 DaemonEvent::HistoryStarted(_)
357 | DaemonEvent::HistoryEnded(_)
358 | DaemonEvent::RecordsAdded(_)
359 | DaemonEvent::SyncCompleted { .. }
360 | DaemonEvent::SyncFailed { .. } => {
361 tracing::warn!("attempted to send internal event via control service");
363 Event::Shutdown(ShutdownEvent {})
364 }
365 }
366}
367
368pub async fn emit_event(event: DaemonEvent) -> Result<()> {
391 emit_event_with_settings(event, None).await
392}
393
394pub async fn emit_event_with_settings(
399 event: DaemonEvent,
400 settings: Option<&Settings>,
401) -> Result<()> {
402 let owned_settings;
404 let settings = match settings {
405 Some(s) => s,
406 None => {
407 owned_settings = Settings::new()?;
408 &owned_settings
409 }
410 };
411
412 let mut client = match ControlClient::from_settings(settings).await {
414 Ok(c) => c,
415 Err(e) => {
416 tracing::debug!(?e, "daemon not running, skipping event emission");
417 return Ok(());
418 }
419 };
420
421 if let Err(e) = client.send_event(event).await {
423 tracing::debug!(?e, "failed to send event to daemon");
424 }
426
427 Ok(())
428}