use std::collections::HashMap;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::{Stream, StreamExt};
use surrealdb::method::QueryStream;
use surrealdb::types::SurrealValue;
use surrealdb::Notification;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use ulid::Ulid;
use crate::connection::client::DatabaseClient;
use crate::error::{Result, SurqlError};
/// Typed stream of live-query notifications.
///
/// Built by `LiveQuery::start`; polling it through its `Stream` impl yields
/// `Result<Notification<T>>` items.
pub struct LiveQuery<T> {
// Notification stream taken from the query response in `start`.
stream: QueryStream<Notification<T>>,
// Zero-sized marker naming `T`; it stores no data.
_marker: PhantomData<T>,
}
impl<T> std::fmt::Debug for LiveQuery<T> {
    /// Prints `LiveQuery { .. }`; no field is included in the output, so the
    /// impl places no `Debug` bound on `T`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("LiveQuery");
        builder.finish_non_exhaustive()
    }
}
impl<T> LiveQuery<T>
where
    T: SurrealValue + Unpin + 'static,
{
    /// Issues `LIVE SELECT * FROM {target};` on `client` and wraps the
    /// resulting notification stream.
    ///
    /// # Errors
    ///
    /// * Propagates whatever `client.config().protocol()` returns on failure.
    /// * Returns [`SurqlError::Streaming`] when the protocol reports that it
    ///   does not support live queries, when the query itself fails, or when
    ///   the notification stream cannot be taken from statement index `0` of
    ///   the response.
    pub async fn start(client: &DatabaseClient, target: &str) -> Result<Self> {
        let proto = client.config().protocol()?;
        if !proto.supports_live_queries() {
            let reason = format!("live queries are not supported over {proto}");
            return Err(SurqlError::Streaming { reason });
        }

        // NOTE(review): `target` is spliced into the statement text verbatim;
        // presumably callers only pass trusted table/record names — confirm.
        let surql = format!("LIVE SELECT * FROM {target};");
        let mut response = match client.inner().query(surql).await {
            Ok(response) => response,
            Err(e) => return Err(streaming_err(&e)),
        };

        // The statement above is the only one sent, hence index 0.
        let stream: QueryStream<Notification<T>> = match response.stream(0) {
            Ok(stream) => stream,
            Err(e) => return Err(streaming_err(&e)),
        };

        Ok(Self {
            stream,
            _marker: PhantomData,
        })
    }
}
impl<T> Stream for LiveQuery<T>
where
    T: SurrealValue + Unpin + 'static,
{
    type Item = Result<Notification<T>>;

    /// Polls the wrapped stream and forwards its state unchanged, except that
    /// an error item is converted into [`SurqlError::Streaming`].
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = Pin::new(&mut self.get_mut().stream);
        match inner.poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(Ok(notification))) => Poll::Ready(Some(Ok(notification))),
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(streaming_err(&e)))),
        }
    }
}
/// Converts a `surrealdb::Error` into [`SurqlError::Streaming`], using the
/// error's `to_string()` output as the `reason`.
fn streaming_err(err: &surrealdb::Error) -> SurqlError {
    let reason = err.to_string();
    SurqlError::Streaming { reason }
}
/// Identifier of a subscription registered with `StreamingManager`.
///
/// A `Copy` newtype around a [`Ulid`]; it is the key of the manager's task table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SubscriptionId(Ulid);
impl SubscriptionId {
fn new() -> Self {
Self(Ulid::new())
}
pub fn as_str(self) -> String {
self.0.to_string()
}
}
impl std::fmt::Display for SubscriptionId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Registry of background live-query tasks, keyed by `SubscriptionId`.
pub struct StreamingManager {
// Reference-counted state that holds the task table.
inner: Arc<StreamingManagerInner>,
}
impl std::fmt::Debug for StreamingManager {
    /// Prints `StreamingManager { .. }`; the task table is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StreamingManager");
        builder.finish_non_exhaustive()
    }
}
impl Default for StreamingManager {
fn default() -> Self {
Self::new()
}
}
/// State behind `StreamingManager`: the table of spawned subscription tasks.
struct StreamingManagerInner {
// Join handle of each spawned subscription task, guarded by a Tokio async mutex.
tasks: Mutex<HashMap<SubscriptionId, JoinHandle<()>>>,
}
impl StreamingManager {
    /// Creates a manager with an empty subscription table.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(StreamingManagerInner {
                tasks: Mutex::new(HashMap::new()),
            }),
        }
    }

    /// Starts a live query on `target` and spawns a Tokio task that passes
    /// every notification to `callback`.
    ///
    /// Error items produced by the stream are logged with `tracing::error!`
    /// and the loop keeps going; the task finishes when the stream yields
    /// `None`. The task's handle stays in the table until [`kill`](Self::kill),
    /// [`drain_all`](Self::drain_all) or the manager's `Drop` removes it, even
    /// if the task has already finished.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`LiveQuery::start`]; in that case no task is
    /// spawned and nothing is registered.
    pub async fn spawn<T, F>(
        &self,
        client: &DatabaseClient,
        target: &str,
        mut callback: F,
    ) -> Result<SubscriptionId>
    where
        T: SurrealValue + Unpin + Send + 'static,
        F: FnMut(Notification<T>) + Send + 'static,
    {
        let mut live: LiveQuery<T> = LiveQuery::start(client, target).await?;
        let id = SubscriptionId::new();
        let handle = tokio::spawn(async move {
            while let Some(item) = live.next().await {
                match item {
                    Ok(n) => callback(n),
                    Err(err) => {
                        // NOTE(review): `target = …` is recorded as a field
                        // named `target`; tracing spells the event-target
                        // override `target: …` — confirm which was intended.
                        tracing::error!(
                            target = "surql::connection::streaming",
                            "live query error: {err}"
                        );
                    }
                }
            }
        });
        self.inner.tasks.lock().await.insert(id, handle);
        Ok(id)
    }

    /// Aborts the task registered under `id` and waits for it to finish.
    ///
    /// Returns `true` when `id` was registered, `false` otherwise.
    pub async fn kill(&self, id: SubscriptionId) -> bool {
        // Take the handle out in its own statement: the mutex guard is a
        // temporary that is dropped at the end of this `let`, so the table is
        // no longer locked while the task is aborted and awaited below.
        let removed = self.inner.tasks.lock().await.remove(&id);
        match removed {
            Some(handle) => {
                handle.abort();
                // The join result (normally a cancellation error) is not needed.
                let _ = handle.await;
                true
            }
            None => false,
        }
    }

    /// Number of handles currently in the table. Finished tasks that have not
    /// been removed yet are still counted.
    pub async fn count(&self) -> usize {
        self.inner.tasks.lock().await.len()
    }

    /// Ids of all handles currently in the table, in unspecified order.
    pub async fn ids(&self) -> Vec<SubscriptionId> {
        self.inner.tasks.lock().await.keys().copied().collect()
    }

    /// Empties the table, then aborts and awaits every removed task in turn.
    pub async fn drain_all(&self) {
        // Collect the handles inside a block so the lock is released before
        // any task is awaited.
        let handles: Vec<JoinHandle<()>> = {
            let mut tasks = self.inner.tasks.lock().await;
            tasks.drain().map(|(_, h)| h).collect()
        };
        for h in handles {
            h.abort();
            let _ = h.await;
        }
    }
}
impl Drop for StreamingManager {
    /// Best-effort cleanup: if the table can be locked without waiting, every
    /// registered task is aborted and the table is emptied. If the lock is
    /// unavailable nothing is done. `Drop` cannot `.await`, so the tasks are
    /// aborted but not joined.
    fn drop(&mut self) {
        if let Ok(mut registry) = self.inner.tasks.try_lock() {
            registry.drain().for_each(|(_, task)| task.abort());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::config::ConnectionConfig;

    /// `LiveQuery::start` fails with a streaming error for a client configured
    /// with an `http://` URL and `enable_live_queries(false)`.
    #[tokio::test]
    async fn live_rejects_http_protocol() {
        let config = ConnectionConfig::builder()
            .url("http://localhost:8000")
            .enable_live_queries(false)
            .build()
            .unwrap();
        let client = DatabaseClient::new(config).unwrap();

        let outcome = LiveQuery::<serde_json::Value>::start(&client, "user").await;

        let err = outcome.unwrap_err();
        assert!(matches!(err, SurqlError::Streaming { .. }));
    }

    /// A new manager has no subscriptions, and killing an unknown id reports `false`.
    #[tokio::test]
    async fn manager_starts_empty() {
        let manager = StreamingManager::new();
        let unknown = SubscriptionId::new();

        assert_eq!(manager.count().await, 0);
        assert!(manager.ids().await.is_empty());
        assert!(!manager.kill(unknown).await);
    }

    /// A failing `LiveQuery::start` is returned from `spawn` and nothing is registered.
    #[tokio::test]
    async fn spawn_surfaces_live_query_errors() {
        let config = ConnectionConfig::builder()
            .url("http://localhost:8000")
            .enable_live_queries(false)
            .build()
            .unwrap();
        let client = DatabaseClient::new(config).unwrap();
        let manager = StreamingManager::new();

        let outcome = manager
            .spawn::<serde_json::Value, _>(&client, "user", |_| {})
            .await;

        let err = outcome.unwrap_err();
        assert!(matches!(err, SurqlError::Streaming { .. }));
        assert_eq!(manager.count().await, 0);
    }

    /// `drain_all` on an empty manager leaves it empty.
    #[tokio::test]
    async fn drain_all_empties_pool() {
        let manager = StreamingManager::new();

        manager.drain_all().await;

        assert_eq!(manager.count().await, 0);
    }

    /// Two freshly created ids differ and render as non-empty text.
    #[test]
    fn subscription_id_is_unique() {
        let first = SubscriptionId::new();
        let second = SubscriptionId::new();

        assert_ne!(first, second);
        assert!(!first.to_string().is_empty());
    }
}