#![allow(
unsafe_code,
reason = "Streaming callbacks require unsafe for FFI trampoline functions and raw pointer casts"
)]
use std::ffi::{CStr, c_void};
use std::sync::mpsc;
use std::{fmt, ptr};
use crate::client::Client;
use crate::conversation::Conversation;
use crate::error::{self, Result};
use crate::ffi::{OwnedHandle, to_ffi_len};
use crate::types::{ConsentEntityType, ConsentState, ConversationType, PreferenceKind};
/// A live native event stream together with the channel its events arrive on.
///
/// Events pushed by the native callback are read with [`Subscription::recv`],
/// [`Subscription::try_recv`], or by iterating. Dropping the value calls
/// `xmtp_stream_end` on the stream.
pub struct Subscription<T> {
/// Receiving half of the channel that the stream callback sends events into.
rx: mpsc::Receiver<T>,
/// Native stream handle; `subscribe` creates it with `xmtp_stream_free` as its destructor.
handle: OwnedHandle<xmtp_sys::XmtpFfiStreamHandle>,
/// Owns the boxed callback whose address the native stream uses as its `context` pointer.
/// Struct fields drop in declaration order, so keeping this field last means the callback
/// is freed only after `handle` has been dropped. Do not reorder these fields.
_ctx: Option<Box<dyn std::any::Any + Send>>,
}
impl<T> Subscription<T> {
#[must_use]
pub fn recv(&self) -> Option<T> {
self.rx.recv().ok()
}
#[must_use]
pub fn try_recv(&self) -> Option<T> {
self.rx.try_recv().ok()
}
pub fn close(&self) {
unsafe { xmtp_sys::xmtp_stream_end(self.handle.as_ptr()) };
}
#[must_use]
pub fn is_closed(&self) -> bool {
unsafe { xmtp_sys::xmtp_stream_is_closed(self.handle.as_ptr()) == 1 }
}
}
impl<T> Iterator for Subscription<T> {
    type Item = T;

    /// Blocks for the next event; equivalent to calling [`Subscription::recv`].
    fn next(&mut self) -> Option<T> {
        self.recv()
    }
}
impl<T> Drop for Subscription<T> {
fn drop(&mut self) {
unsafe { xmtp_sys::xmtp_stream_end(self.handle.as_ptr()) };
}
}
impl<T> fmt::Debug for Subscription<T> {
    /// Formats as `Subscription { is_closed: <bool> }`; the channel and handle are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let closed = self.is_closed();
        let mut view = f.debug_struct("Subscription");
        view.field("is_closed", &closed);
        view.finish()
    }
}
/// Identifies one message delivered by [`messages`] or [`conversation_messages`].
///
/// As produced by those streams, either field is empty when the native side returned a null
/// pointer or a string that is not valid UTF-8.
#[derive(Debug, Clone)]
pub struct MessageEvent {
/// ID of the message (from `xmtp_single_message_id`).
pub message_id: String,
/// ID of the conversation the message belongs to (from `xmtp_single_message_group_id`).
pub conversation_id: String,
}
/// A single decoded consent record, as delivered by [`consent`] or inside a
/// [`PreferenceUpdate`].
#[derive(Debug, Clone)]
pub struct ConsentUpdate {
/// Kind of entity the record refers to.
pub entity_type: ConsentEntityType,
/// Consent state recorded for the entity.
pub state: ConsentState,
/// The entity identifier string carried by the native record.
pub entity: String,
}
/// One preference change delivered by [`preferences`].
#[derive(Debug, Clone)]
pub struct PreferenceUpdate {
/// Which kind of preference changed.
pub kind: PreferenceKind,
/// As produced by the preference stream, this is `Some` only when `kind` is
/// `PreferenceKind::Consent` and the embedded record's entity type and state were recognized.
pub consent: Option<ConsentUpdate>,
}
fn subscribe<T: Send + 'static, F: Send + 'static>(
callback: F,
rx: mpsc::Receiver<T>,
start: impl FnOnce(*mut c_void, *mut *mut xmtp_sys::XmtpFfiStreamHandle) -> i32,
) -> Result<Subscription<T>> {
let boxed = Box::new(callback);
let ctx_ptr = Box::into_raw(boxed).cast::<c_void>();
let mut out: *mut xmtp_sys::XmtpFfiStreamHandle = ptr::null_mut();
let rc = start(ctx_ptr, &raw mut out);
if rc != 0 {
let _ = unsafe { Box::from_raw(ctx_ptr.cast::<F>()) };
return Err(error::last_ffi_error());
}
let handle = OwnedHandle::new(out, xmtp_sys::xmtp_stream_free)?;
let ctx_box = unsafe { Box::from_raw(ctx_ptr.cast::<F>()) };
Ok(Subscription {
rx,
handle,
_ctx: Some(ctx_box),
})
}
/// Streams conversations for `client` as [`Conversation`] values.
///
/// `conversation_type` limits the stream to one type. `None` is passed to the native layer
/// as `-1` (presumably "no filter" — TODO confirm against the FFI docs).
///
/// NOTE(review): the returned subscription does not borrow `client`; confirm the native
/// stream keeps the client alive on its own.
///
/// # Errors
/// Propagates any error from starting the native stream or wrapping its handle.
pub fn conversations(
    client: &Client,
    conversation_type: Option<ConversationType>,
) -> Result<Subscription<Conversation>> {
    let client_ptr = client.handle.as_ptr();
    let type_filter = conversation_type.map_or(-1, |kind| kind as i32);
    let (sender, receiver) = mpsc::channel();
    // A send only fails once the receiver is gone; the conversation is then simply dropped.
    let forward = move |conversation: Conversation| drop(sender.send(conversation));
    let callback: Box<dyn Fn(Conversation) + Send> = Box::new(forward);
    subscribe(callback, receiver, move |context, handle_out| {
        // SAFETY: `client_ptr` is a live client handle; `context` points at `callback`,
        // matching the type `conv_trampoline` casts it to.
        unsafe {
            xmtp_sys::xmtp_stream_conversations(
                client_ptr,
                type_filter,
                Some(conv_trampoline),
                None,
                context,
                handle_out,
            )
        }
    })
}
/// Streams messages across all of `client`'s conversations as [`MessageEvent`]s.
///
/// `conversation_type` limits the stream to one type; `None` is passed as `-1`.
/// `consent_states` filters by consent. An empty slice is passed as a null pointer with
/// length `0` (presumably "no filter" — TODO confirm against the FFI docs).
///
/// NOTE(review): the returned subscription does not borrow `client`; confirm the native
/// stream keeps the client alive on its own.
///
/// # Errors
/// Fails if `consent_states` is too long for the FFI length type. Otherwise, propagates any
/// error from starting the native stream or wrapping its handle.
pub fn messages(
    client: &Client,
    conversation_type: Option<ConversationType>,
    consent_states: &[ConsentState],
) -> Result<Subscription<MessageEvent>> {
    let client_ptr = client.handle.as_ptr();
    let type_filter = conversation_type.map_or(-1, |kind| kind as i32);
    // Only a raw pointer into `states` crosses the FFI, so it must outlive the `subscribe`
    // call below. It is a named local for that reason.
    // NOTE(review): assumes the native side copies the states during the start call.
    let states: Vec<i32> = consent_states.iter().map(|state| *state as i32).collect();
    let states_len = to_ffi_len(states.len())?;
    let states_ptr = if states.is_empty() {
        ptr::null()
    } else {
        states.as_ptr()
    };
    let (sender, receiver) = mpsc::channel();
    let forward = move |message_id: String, conversation_id: String| {
        drop(sender.send(MessageEvent {
            message_id,
            conversation_id,
        }));
    };
    let callback: Box<dyn Fn(String, String) + Send> = Box::new(forward);
    subscribe(callback, receiver, move |context, handle_out| {
        // SAFETY: `client_ptr` is a live client handle; `states_ptr`/`states_len` describe
        // `states`, which is still alive; `context` matches what `msg_trampoline` expects.
        unsafe {
            xmtp_sys::xmtp_stream_all_messages(
                client_ptr,
                type_filter,
                states_ptr,
                states_len,
                Some(msg_trampoline),
                None,
                context,
                handle_out,
            )
        }
    })
}
/// Streams messages for a single `conversation` as [`MessageEvent`]s.
///
/// NOTE(review): the returned subscription does not borrow `conversation`; confirm the
/// native stream keeps the conversation handle alive on its own.
///
/// # Errors
/// Propagates any error from starting the native stream or wrapping its handle.
pub fn conversation_messages(conversation: &Conversation) -> Result<Subscription<MessageEvent>> {
    let conv_ptr = conversation.handle_ptr();
    let (sender, receiver) = mpsc::channel();
    let forward = move |message_id: String, conversation_id: String| {
        let event = MessageEvent {
            message_id,
            conversation_id,
        };
        drop(sender.send(event));
    };
    let callback: Box<dyn Fn(String, String) + Send> = Box::new(forward);
    subscribe(callback, receiver, move |context, handle_out| {
        // SAFETY: `conv_ptr` is a live conversation handle; `context` matches the type
        // `msg_trampoline` casts it to.
        unsafe {
            xmtp_sys::xmtp_conversation_stream_messages(
                conv_ptr,
                Some(msg_trampoline),
                None,
                context,
                handle_out,
            )
        }
    })
}
/// Streams consent-record changes for `client`.
///
/// Each item is a non-empty batch of decoded [`ConsentUpdate`]s.
///
/// # Errors
/// Propagates any error from starting the native stream or wrapping its handle.
pub fn consent(client: &Client) -> Result<Subscription<Vec<ConsentUpdate>>> {
    let client_ptr = client.handle.as_ptr();
    let (sender, receiver) = mpsc::channel();
    let forward = move |batch: Vec<ConsentUpdate>| drop(sender.send(batch));
    let callback: Box<dyn Fn(Vec<ConsentUpdate>) + Send> = Box::new(forward);
    subscribe(callback, receiver, move |context, handle_out| {
        // SAFETY: `client_ptr` is a live client handle; `context` matches the type
        // `consent_trampoline` casts it to.
        unsafe {
            xmtp_sys::xmtp_stream_consent(
                client_ptr,
                Some(consent_trampoline),
                None,
                context,
                handle_out,
            )
        }
    })
}
/// Streams preference changes for `client`.
///
/// Each item is a non-empty batch of decoded [`PreferenceUpdate`]s.
///
/// # Errors
/// Propagates any error from starting the native stream or wrapping its handle.
pub fn preferences(client: &Client) -> Result<Subscription<Vec<PreferenceUpdate>>> {
    let client_ptr = client.handle.as_ptr();
    let (sender, receiver) = mpsc::channel();
    let forward = move |batch: Vec<PreferenceUpdate>| drop(sender.send(batch));
    let callback: Box<dyn Fn(Vec<PreferenceUpdate>) + Send> = Box::new(forward);
    subscribe(callback, receiver, move |context, handle_out| {
        // SAFETY: `client_ptr` is a live client handle; `context` matches the type
        // `pref_trampoline` casts it to.
        unsafe {
            xmtp_sys::xmtp_stream_preferences(
                client_ptr,
                Some(pref_trampoline),
                None,
                context,
                handle_out,
            )
        }
    })
}
/// Streams the IDs of deleted messages for `client`.
///
/// IDs that are not valid UTF-8 are not delivered.
///
/// # Errors
/// Propagates any error from starting the native stream or wrapping its handle.
pub fn message_deletions(client: &Client) -> Result<Subscription<String>> {
    let client_ptr = client.handle.as_ptr();
    let (sender, receiver) = mpsc::channel();
    let forward = move |message_id: String| drop(sender.send(message_id));
    let callback: Box<dyn Fn(String) + Send> = Box::new(forward);
    subscribe(callback, receiver, move |context, handle_out| {
        // SAFETY: `client_ptr` is a live client handle; `context` matches the type
        // `deletion_trampoline` casts it to.
        unsafe {
            xmtp_sys::xmtp_stream_message_deletions(
                client_ptr,
                Some(deletion_trampoline),
                None,
                context,
                handle_out,
            )
        }
    })
}
/// Native callback for [`conversations`]: wraps each new conversation and forwards it.
///
/// Ownership of `conv` is taken with `Conversation::from_raw` before `context` is checked.
/// Even on the defensive null-`context` path, the wrapper is dropped rather than leaked.
/// This mirrors `msg_trampoline`, which frees its message on that path. (Assumes
/// `from_raw` takes ownership, as the forwarding path already relies on.) A conversation
/// that fails to wrap is skipped.
///
/// NOTE(review): the boxed callback is `Send` but not declared `Sync`. If the native runtime
/// may invoke it concurrently from several threads, the callback type should also be `Sync`.
/// Confirm against the FFI threading contract.
unsafe extern "C" fn conv_trampoline(
    conv: *mut xmtp_sys::XmtpFfiConversation,
    context: *mut c_void,
) {
    if conv.is_null() {
        return;
    }
    // Wrap first so the native conversation is owned on every path below.
    let wrapped = Conversation::from_raw(conv);
    if context.is_null() {
        // No callback to deliver to; `wrapped` is dropped on return, releasing the handle.
        return;
    }
    // SAFETY: `context` was produced by `subscribe` from a boxed
    // `Box<dyn Fn(Conversation) + Send>`, kept alive by the owning `Subscription`.
    let cb = unsafe { &*context.cast::<Box<dyn Fn(Conversation) + Send>>() };
    if let Ok(conversation) = wrapped {
        cb(conversation);
    }
}
/// Native callback for [`messages`] and [`conversation_messages`].
///
/// Reads the message id and its conversation (group) id, frees the native message, and
/// forwards both ids. A null id pointer, or one that is not valid UTF-8, is forwarded as an
/// empty string. If `context` is null, the message is still freed.
unsafe extern "C" fn msg_trampoline(msg: *mut xmtp_sys::XmtpFfiMessage, context: *mut c_void) {
    if msg.is_null() {
        return;
    }
    if context.is_null() {
        // SAFETY: `msg` is non-null and handed to this callback to release.
        unsafe { xmtp_sys::xmtp_message_free(msg) };
        return;
    }
    // Each id is released separately with `xmtp_free_string`, so the code treats it as
    // independent of `msg` and frees `msg` before reading the ids.
    // SAFETY: `msg` is non-null and valid until it is freed just below.
    let raw_ids = unsafe {
        [
            xmtp_sys::xmtp_single_message_id(msg),
            xmtp_sys::xmtp_single_message_group_id(msg),
        ]
    };
    // SAFETY: `msg` is not touched after this call.
    unsafe { xmtp_sys::xmtp_message_free(msg) };
    // Converted in order: message id first, then group id.
    let [id, gid] = raw_ids.map(|raw| {
        if raw.is_null() {
            return String::new();
        }
        // SAFETY: a non-null id is a NUL-terminated string owned by us until freed here.
        let owned = unsafe { CStr::from_ptr(raw) }
            .to_str()
            .unwrap_or_default()
            .to_owned();
        unsafe { xmtp_sys::xmtp_free_string(raw) };
        owned
    });
    // SAFETY: `context` was produced by `subscribe` from a boxed
    // `Box<dyn Fn(String, String) + Send>`, kept alive by the owning `Subscription`.
    let cb = unsafe { &*context.cast::<Box<dyn Fn(String, String) + Send>>() };
    cb(id, gid);
}
/// Native callback for [`consent`]: decodes a batch of consent records and forwards them.
///
/// A record is skipped when:
/// - its entity type or state is unrecognized;
/// - its entity pointer is null;
/// - its entity string is not valid UTF-8.
///
/// Null `records`/`context` or a non-positive `count` is ignored. The callback runs only for
/// non-empty batches.
unsafe extern "C" fn consent_trampoline(
    records: *const xmtp_sys::XmtpFfiConsentRecord,
    count: i32,
    context: *mut c_void,
) {
    if context.is_null() || records.is_null() {
        return;
    }
    // `try_from` rejects negative counts without a lossy cast; zero means nothing to do.
    let Ok(len) = usize::try_from(count) else {
        return;
    };
    if len == 0 {
        return;
    }
    // SAFETY: the native side passes `count` initialized records at `records`, valid for the
    // duration of this call.
    let slice = unsafe { std::slice::from_raw_parts(records, len) };
    let updates: Vec<ConsentUpdate> = slice
        .iter()
        .filter_map(|r| {
            let entity_type = ConsentEntityType::from_ffi(r.entity_type as i32)?;
            let state = ConsentState::from_ffi(r.state as i32)?;
            // Passing null to `CStr::from_ptr` is UB. Skip such records, the same way
            // records with non-UTF-8 entities are skipped.
            if r.entity.is_null() {
                return None;
            }
            // SAFETY: non-null; assumed NUL-terminated and valid for this call.
            let entity = unsafe { CStr::from_ptr(r.entity) }
                .to_str()
                .ok()?
                .to_owned();
            Some(ConsentUpdate {
                entity_type,
                state,
                entity,
            })
        })
        .collect();
    if updates.is_empty() {
        return;
    }
    // SAFETY: `context` was produced by `subscribe` from a boxed
    // `Box<dyn Fn(Vec<ConsentUpdate>) + Send>`, kept alive by the owning `Subscription`.
    let cb = unsafe { &*context.cast::<Box<dyn Fn(Vec<ConsentUpdate>) + Send>>() };
    cb(updates);
}
/// Native callback for [`preferences`]: decodes a batch of preference updates and forwards them.
///
/// Updates whose kind is unrecognized are skipped. For `PreferenceKind::Consent` updates, the
/// embedded consent record is decoded:
/// - `consent` is `None` if the record's entity type or state is unrecognized;
/// - a null or non-UTF-8 entity becomes an empty string.
///
/// The callback runs only for non-empty batches.
unsafe extern "C" fn pref_trampoline(
    updates: *const xmtp_sys::XmtpFfiPreferenceUpdate,
    count: i32,
    context: *mut c_void,
) {
    let usable = !context.is_null() && !updates.is_null() && count > 0;
    if !usable {
        return;
    }
    // SAFETY: the native side passes `count` (> 0) initialized updates at `updates`, valid
    // for the duration of this call.
    let raw_updates = unsafe { std::slice::from_raw_parts(updates, count.unsigned_abs() as usize) };
    let mut decoded = Vec::new();
    for raw in raw_updates {
        let Some(kind) = PreferenceKind::from_ffi(raw.kind as i32) else {
            continue;
        };
        let consent = if kind == PreferenceKind::Consent {
            let record = &raw.consent;
            let entity_type = ConsentEntityType::from_ffi(record.entity_type as i32);
            let state = ConsentState::from_ffi(record.state as i32);
            let entity = if record.entity.is_null() {
                String::new()
            } else {
                // SAFETY: non-null; assumed NUL-terminated and valid for this call.
                unsafe { CStr::from_ptr(record.entity) }
                    .to_str()
                    .unwrap_or_default()
                    .to_owned()
            };
            match (entity_type, state) {
                (Some(entity_type), Some(state)) => Some(ConsentUpdate {
                    entity_type,
                    state,
                    entity,
                }),
                _ => None,
            }
        } else {
            None
        };
        decoded.push(PreferenceUpdate { kind, consent });
    }
    if decoded.is_empty() {
        return;
    }
    // SAFETY: `context` was produced by `subscribe` from a boxed
    // `Box<dyn Fn(Vec<PreferenceUpdate>) + Send>`, kept alive by the owning `Subscription`.
    let cb = unsafe { &*context.cast::<Box<dyn Fn(Vec<PreferenceUpdate>) + Send>>() };
    cb(decoded);
}
/// Native callback for [`message_deletions`]: forwards the deleted message's ID.
///
/// Null pointers are ignored, and IDs that are not valid UTF-8 are dropped silently.
unsafe extern "C" fn deletion_trampoline(
    message_id: *const std::ffi::c_char,
    context: *mut c_void,
) {
    if message_id.is_null() || context.is_null() {
        return;
    }
    // SAFETY: non-null; assumed NUL-terminated and valid for the duration of this call.
    let Ok(id) = unsafe { CStr::from_ptr(message_id) }.to_str() else {
        return;
    };
    // SAFETY: `context` was produced by `subscribe` from a boxed
    // `Box<dyn Fn(String) + Send>`, kept alive by the owning `Subscription`.
    let cb = unsafe { &*context.cast::<Box<dyn Fn(String) + Send>>() };
    cb(id.to_owned());
}