use busrt::broker::{Broker, ServerConfig};
use busrt::rpc::{RpcClient, RpcError, RpcEvent, RpcHandlers, RpcResult};
use busrt::{async_trait, cursors};
use futures::{Stream, TryStreamExt};
use serde::Serialize;
use sqlx::{
postgres::{PgPoolOptions, PgRow},
Row,
};
use std::pin::Pin;
use std::str::FromStr;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time::sleep;
/// Time-to-live handed to every cursor's [`cursors::Meta`] when it is created (30 seconds).
const CURSOR_TTL: Duration = Duration::from_millis(30_000);
type DbStream = Pin<Box<dyn Stream<Item = Result<PgRow, sqlx::Error>> + Send>>;
/// One record of the `customers` query, serialized and returned to cursor consumers.
///
/// Keep the field order as is: serde serializes fields in declaration order.
#[derive(Serialize)]
struct Customer {
    /// Column 0 of the query (`id`).
    id: i64,
    /// Column 1 of the query (`name`).
    name: String,
}
/// Server-side cursor that hands out rows of the customers query on demand.
struct CustomerCursor {
    /// Pending query rows. Wrapped in an async mutex because the cursor methods take
    /// `&self` and the stream is polled (mutably) across `.await` points.
    stream: Mutex<DbStream>,
    /// Cursor bookkeeping (TTL, finished flag) used by the cursor map.
    meta: cursors::Meta,
}
#[async_trait]
impl cursors::Cursor for CustomerCursor {
async fn next(&self) -> Result<Option<Vec<u8>>, RpcError> {
if let Some(row) = self
.stream
.lock()
.await
.try_next()
.await
.map_err(|_| RpcError::internal(None))?
{
let id: i64 = row.try_get(0).map_err(|_| RpcError::internal(None))?;
let name: String = row.try_get(1).map_err(|_| RpcError::internal(None))?;
Ok(Some(rmp_serde::to_vec_named(&Customer { id, name })?))
} else {
self.meta().mark_finished();
Ok(None)
}
}
async fn next_bulk(&self, count: usize) -> Result<Vec<u8>, RpcError> {
let mut result: Vec<Customer> = Vec::with_capacity(count);
if count > 0 {
let mut stream = self.stream.lock().await;
while let Some(row) = stream
.try_next()
.await
.map_err(|_| RpcError::internal(None))?
{
let id: i64 = row.try_get(0).map_err(|_| RpcError::internal(None))?;
let name: String = row.try_get(1).map_err(|_| RpcError::internal(None))?;
result.push(Customer { id, name });
if result.len() == count {
break;
}
}
}
if result.len() < count {
self.meta.mark_finished();
}
Ok(rmp_serde::to_vec_named(&result)?)
}
fn meta(&self) -> &cursors::Meta {
&self.meta
}
}
impl CustomerCursor {
fn new(stream: DbStream) -> Self {
Self {
stream: Mutex::new(stream),
meta: cursors::Meta::new(CURSOR_TTL),
}
}
}
/// RPC handler state: the database pool and the registry of open cursors.
struct MyHandlers {
    /// PostgreSQL connection pool used to run cursor queries.
    pool: sqlx::PgPool,
    /// Open cursors, looked up by the id returned to the client.
    cursors: cursors::Map,
}
#[async_trait]
impl RpcHandlers for MyHandlers {
    /// Dispatches incoming RPC calls by method name:
    ///
    /// * `Ccustomers` — opens a cursor over `select id, name from customers` and replies
    ///   with a serialized [`cursors::Payload`] built from the new cursor's id;
    /// * `N` — returns the next record of the cursor named in the payload;
    /// * `NB` — returns a bulk of records, sized by the payload's bulk number;
    /// * anything else — a method error.
    async fn handle_call(&self, event: RpcEvent) -> RpcResult {
        let request_bytes = event.payload();
        let method = event.parse_method()?;
        match method {
            "Ccustomers" => {
                let rows = sqlx::query("select id, name from customers").fetch(&self.pool);
                let cursor_id = self.cursors.add(CustomerCursor::new(rows)).await;
                let reply = cursors::Payload::from(cursor_id);
                Ok(Some(rmp_serde::to_vec_named(&reply)?))
            }
            "N" | "NB" => {
                // Both cursor-read methods carry the same payload type.
                let request: cursors::Payload = rmp_serde::from_slice(request_bytes)?;
                if method == "N" {
                    self.cursors.next(request.uuid()).await
                } else {
                    self.cursors
                        .next_bulk(request.uuid(), request.bulk_number())
                        .await
                }
            }
            _ => Err(RpcError::method(None)),
        }
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Embedded broker, reachable over a UNIX socket.
    let mut broker = Broker::new();
    broker
        .spawn_unix_server("/tmp/busrt.sock", ServerConfig::default())
        .await?;
    // In-process client named "db" that serves the RPC methods below.
    let db_client = broker.register_client("db").await?;

    // PostgreSQL pool the cursors read from.
    let connect_options =
        sqlx::postgres::PgConnectOptions::from_str("postgres://tests:xxx@localhost/tests")?;
    let pool = PgPoolOptions::new().connect_with(connect_options).await?;

    let handlers = MyHandlers {
        pool,
        cursors: cursors::Map::new(Duration::from_secs(30)),
    };
    // Bound to a name so the RPC client stays alive for the whole process
    // (presumably dropping it would stop request handling — not verified here).
    let _rpc = RpcClient::new(db_client, handlers);

    // Park the main task forever.
    let idle_tick = Duration::from_secs(1);
    loop {
        sleep(idle_tick).await;
    }
}