sos_database/
audit_provider.rs1use crate::{
3 entity::{AuditEntity, AuditRecord, AuditRow},
4 Error,
5};
6use async_sqlite::Client;
7use async_trait::async_trait;
8use futures::stream::BoxStream;
9use sos_audit::{AuditEvent, AuditStreamSink};
10use tokio_stream::wrappers::ReceiverStream;
11
12pub struct AuditDatabaseProvider<E>
14where
15 E: std::error::Error
16 + std::fmt::Debug
17 + From<crate::Error>
18 + Send
19 + Sync
20 + 'static,
21{
22 client: Client,
23 marker: std::marker::PhantomData<E>,
24}
25
26impl<E> AuditDatabaseProvider<E>
27where
28 E: std::error::Error
29 + std::fmt::Debug
30 + From<crate::Error>
31 + Send
32 + Sync
33 + 'static,
34{
35 pub fn new(client: Client) -> Self {
37 Self {
38 client,
39 marker: std::marker::PhantomData,
40 }
41 }
42}
43
44#[async_trait]
45impl<E> AuditStreamSink for AuditDatabaseProvider<E>
46where
47 E: std::error::Error
48 + std::fmt::Debug
49 + From<crate::Error>
50 + From<std::io::Error>
51 + Send
52 + Sync
53 + 'static,
54{
55 type Error = E;
56
57 async fn append_audit_events(
58 &self,
59 events: &[AuditEvent],
60 ) -> std::result::Result<(), Self::Error> {
61 let mut audit_events = Vec::new();
62 for event in events {
63 audit_events.push(event.try_into()?);
64 }
65 self.client
66 .conn(move |conn| {
67 let audit = AuditEntity::new(&conn);
68 audit.insert_audit_logs(audit_events.as_slice())?;
69 Ok(())
70 })
71 .await
72 .map_err(Error::from)?;
73 Ok(())
74 }
75
76 async fn audit_stream(
77 &self,
78 reverse: bool,
79 ) -> std::result::Result<
80 BoxStream<'static, std::result::Result<AuditEvent, Self::Error>>,
81 Self::Error,
82 > {
83 let (tx, rx) = tokio::sync::mpsc::channel::<
84 std::result::Result<AuditEvent, Self::Error>,
85 >(16);
86
87 let client = self.client.clone();
88 tokio::task::spawn(async move {
89 client
90 .conn_and_then(move |conn| {
91 let mut stmt = if reverse {
92 conn.prepare(
93 "SELECT * FROM audit_logs ORDER BY log_id DESC",
94 )?
95 } else {
96 conn.prepare(
97 "SELECT * FROM audit_logs ORDER BY log_id ASC",
98 )?
99 };
100 let mut rows = stmt.query([])?;
101
102 while let Some(row) = rows.next()? {
103 if tx.is_closed() {
104 break;
105 }
106 let row: AuditRow = row.try_into()?;
107 let record: AuditRecord = row.try_into()?;
108 let inner_tx = tx.clone();
109 let res = futures::executor::block_on(async move {
110 inner_tx.send(Ok(record.event)).await
111 });
112 if let Err(e) = res {
113 tracing::error!(error = %e);
114 break;
115 }
116 }
117
118 Ok::<_, Error>(())
119 })
120 .await?;
121 Ok::<_, Self::Error>(())
122 });
123
124 Ok(Box::pin(ReceiverStream::new(rx)))
125 }
126}