Skip to main content

rustex_runtime/
lib.rs

1use std::fmt as stdfmt;
2use std::{collections::BTreeMap, marker::PhantomData, pin::Pin, task};
3
4use convex::{
5    ConvexClient, FunctionResult, QuerySetSubscription, QuerySubscription, SubscriberId, Value,
6};
7use futures_core::Stream;
8use serde::{Serialize, de::DeserializeOwned};
9use thiserror::Error;
10use time::macros::format_description;
11use tracing::{Instrument, debug, trace};
12use tracing_subscriber::fmt::format::Writer;
13use tracing_subscriber::fmt::time::{FormatTime, UtcTime};
14use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields};
15use tracing_subscriber::registry::LookupSpan;
16use tracing_subscriber::{EnvFilter, fmt};
17
18pub trait FunctionSpec {
19    type Args: Serialize;
20    type Output: DeserializeOwned;
21
22    const PATH: &'static str;
23}
24
/// Marker for specs accepted by `RustexClient::query`, `RustexClient::subscribe`
/// and the `Stream` impl of `TypedSubscription`.
pub trait QuerySpec: FunctionSpec {}
/// Marker for specs accepted by `RustexClient::mutation`.
pub trait MutationSpec: FunctionSpec {}
/// Marker for specs accepted by `RustexClient::action`.
pub trait ActionSpec: FunctionSpec {}
28
/// A `QuerySubscription` tagged with the function spec `F` it was created for.
///
/// When `F: QuerySpec`, the wrapper is a `Stream` whose items are decoded into
/// `F::Output`.
pub struct TypedSubscription<F> {
    /// The untyped subscription that actually produces updates.
    inner: QuerySubscription,
    /// Carries `F` at the type level only; no value of `F` is stored.
    marker: PhantomData<fn() -> F>,
}
33
34impl<F> TypedSubscription<F> {
35    pub fn from_inner(inner: QuerySubscription) -> Self {
36        Self {
37            inner,
38            marker: PhantomData,
39        }
40    }
41
42    pub fn id(&self) -> &SubscriberId {
43        self.inner.id()
44    }
45
46    pub fn inner(&self) -> &QuerySubscription {
47        &self.inner
48    }
49
50    pub fn inner_mut(&mut self) -> &mut QuerySubscription {
51        &mut self.inner
52    }
53
54    pub fn into_inner(self) -> QuerySubscription {
55        self.inner
56    }
57}
58
59impl<F> std::fmt::Debug for TypedSubscription<F> {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("TypedSubscription")
62            .field("subscriber_id", self.id())
63            .finish()
64    }
65}
66
67impl<F> Stream for TypedSubscription<F>
68where
69    F: QuerySpec,
70{
71    type Item = Result<F::Output, RuntimeError>;
72
73    fn poll_next(
74        mut self: Pin<&mut Self>,
75        cx: &mut task::Context<'_>,
76    ) -> task::Poll<Option<Self::Item>> {
77        match Pin::new(&mut self.inner).poll_next(cx) {
78            task::Poll::Ready(Some(result)) => task::Poll::Ready(Some(decode_result(result))),
79            task::Poll::Ready(None) => task::Poll::Ready(None),
80            task::Poll::Pending => task::Poll::Pending,
81        }
82    }
83}
84
85pub struct RustexClient {
86    inner: ConvexClient,
87}
88
89impl Clone for RustexClient {
90    fn clone(&self) -> Self {
91        Self {
92            inner: self.inner.clone(),
93        }
94    }
95}
96
97impl RustexClient {
98    #[tracing::instrument(name = "rustex_runtime.client.new", skip_all, fields(deployment_url))]
99    pub async fn new(deployment_url: &str) -> anyhow::Result<Self> {
100        debug!("connecting Convex client");
101        Ok(Self {
102            inner: ConvexClient::new(deployment_url).await?,
103        })
104    }
105
106    pub fn from_inner(inner: ConvexClient) -> Self {
107        Self { inner }
108    }
109
110    pub fn inner(&self) -> &ConvexClient {
111        &self.inner
112    }
113
114    pub fn inner_mut(&mut self) -> &mut ConvexClient {
115        &mut self.inner
116    }
117
118    pub fn into_inner(self) -> ConvexClient {
119        self.inner
120    }
121
122    pub async fn query<F>(
123        &mut self,
124        _function: F,
125        args: &F::Args,
126    ) -> Result<F::Output, RuntimeError>
127    where
128        F: QuerySpec,
129    {
130        let encoded_args = encode_args(args)?;
131        let span = tracing::info_span!("rustex_runtime.query", convex.function = F::PATH);
132        async move {
133            debug!(argument_count = encoded_args.len(), "executing typed query");
134            let result = self.inner.query(F::PATH, encoded_args).await?;
135            decode_result(result)
136        }
137        .instrument(span)
138        .await
139    }
140
141    pub async fn subscribe<F>(
142        &mut self,
143        _function: F,
144        args: &F::Args,
145    ) -> Result<TypedSubscription<F>, RuntimeError>
146    where
147        F: QuerySpec,
148    {
149        let encoded_args = encode_args(args)?;
150        let span = tracing::info_span!("rustex_runtime.subscribe", convex.function = F::PATH);
151        async move {
152            debug!(
153                argument_count = encoded_args.len(),
154                "creating typed subscription"
155            );
156            let subscription = self.inner.subscribe(F::PATH, encoded_args).await?;
157            Ok(TypedSubscription::from_inner(subscription))
158        }
159        .instrument(span)
160        .await
161    }
162
163    pub async fn mutation<F>(
164        &mut self,
165        _function: F,
166        args: &F::Args,
167    ) -> Result<F::Output, RuntimeError>
168    where
169        F: MutationSpec,
170    {
171        let encoded_args = encode_args(args)?;
172        let span = tracing::info_span!("rustex_runtime.mutation", convex.function = F::PATH);
173        async move {
174            debug!(
175                argument_count = encoded_args.len(),
176                "executing typed mutation"
177            );
178            let result = self.inner.mutation(F::PATH, encoded_args).await?;
179            decode_result(result)
180        }
181        .instrument(span)
182        .await
183    }
184
185    pub async fn action<F>(
186        &mut self,
187        _function: F,
188        args: &F::Args,
189    ) -> Result<F::Output, RuntimeError>
190    where
191        F: ActionSpec,
192    {
193        let encoded_args = encode_args(args)?;
194        let span = tracing::info_span!("rustex_runtime.action", convex.function = F::PATH);
195        async move {
196            debug!(
197                argument_count = encoded_args.len(),
198                "executing typed action"
199            );
200            let result = self.inner.action(F::PATH, encoded_args).await?;
201            decode_result(result)
202        }
203        .instrument(span)
204        .await
205    }
206
207    pub fn watch_all(&self) -> QuerySetSubscription {
208        self.inner.watch_all()
209    }
210}
211
212#[derive(Debug, Error)]
213pub enum RuntimeError {
214    #[error(transparent)]
215    Transport(#[from] anyhow::Error),
216    #[error("Convex function returned an error message: {0}")]
217    FunctionMessage(String),
218    #[error("Convex function raised an application error: {message}")]
219    ConvexError {
220        message: String,
221        data: serde_json::Value,
222    },
223    #[error("arguments must serialize to an object or null")]
224    InvalidArgsShape,
225    #[error(transparent)]
226    Serde(#[from] serde_json::Error),
227}
228
229pub fn init_default_tracing() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
230    fmt()
231        .event_format(FlatLogFormat::default())
232        .with_env_filter(EnvFilter::from_default_env())
233        .try_init()
234}
235
/// Event formatter that writes one line per event: timestamp, level, then the
/// event's fields.
#[derive(Clone, Debug)]
struct FlatLogFormat {
    /// Timer used to render the leading timestamp of every line.
    timer: UtcTime<time::format_description::OwnedFormatItem>,
}
240
241impl Default for FlatLogFormat {
242    fn default() -> Self {
243        Self {
244            timer: UtcTime::new(
245                format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z").into(),
246            ),
247        }
248    }
249}
250
251impl<S, N> FormatEvent<S, N> for FlatLogFormat
252where
253    S: tracing::Subscriber + for<'lookup> LookupSpan<'lookup>,
254    N: for<'writer> FormatFields<'writer> + 'static,
255{
256    fn format_event(
257        &self,
258        ctx: &FmtContext<'_, S, N>,
259        mut writer: Writer<'_>,
260        event: &tracing::Event<'_>,
261    ) -> stdfmt::Result {
262        self.timer.format_time(&mut writer)?;
263        write_level(&mut writer, event.metadata().level())?;
264        ctx.field_format().format_fields(writer.by_ref(), event)?;
265        writeln!(writer)
266    }
267}
268
269fn write_level(writer: &mut Writer<'_>, level: &tracing::Level) -> stdfmt::Result {
270    if writer.has_ansi_escapes() {
271        let color = match *level {
272            tracing::Level::ERROR => "\x1b[31m",
273            tracing::Level::WARN => "\x1b[33m",
274            tracing::Level::INFO => "\x1b[32m",
275            tracing::Level::DEBUG => "\x1b[34m",
276            tracing::Level::TRACE => "\x1b[35m",
277        };
278        write!(writer, " {}{:>5}\x1b[0m ", color, level)
279    } else {
280        write!(writer, " {:>5} ", level)
281    }
282}
283
284#[tracing::instrument(name = "rustex_runtime.encode_args", skip_all)]
285pub fn encode_args<T: Serialize>(args: &T) -> Result<BTreeMap<String, Value>, RuntimeError> {
286    let json = serde_json::to_value(args)?;
287    match json {
288        serde_json::Value::Null => Ok(BTreeMap::new()),
289        serde_json::Value::Object(map) => map
290            .into_iter()
291            .map(|(key, value)| Ok((key, Value::try_from(value)?)))
292            .collect::<Result<BTreeMap<_, _>, RuntimeError>>()
293            .inspect(|encoded| trace!(argument_count = encoded.len(), "encoded Convex args")),
294        _ => Err(RuntimeError::InvalidArgsShape),
295    }
296}
297
298#[tracing::instrument(name = "rustex_runtime.decode_result", skip_all)]
299pub fn decode_result<T: DeserializeOwned>(result: FunctionResult) -> Result<T, RuntimeError> {
300    match result {
301        FunctionResult::Value(value) => {
302            let json: serde_json::Value = value.into();
303            trace!("deserializing Convex function value");
304            Ok(serde_json::from_value(json)?)
305        }
306        FunctionResult::ErrorMessage(message) => {
307            debug!("Convex function returned an error message");
308            Err(RuntimeError::FunctionMessage(message))
309        }
310        FunctionResult::ConvexError(error) => Err(RuntimeError::ConvexError {
311            message: error.message,
312            data: error.data.into(),
313        }),
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::{RuntimeError, decode_result, encode_args};
    use convex::{FunctionResult, Value};
    use serde::{Deserialize, Serialize};
    use std::collections::BTreeMap;

    /// Sample argument struct used to exercise `encode_args`.
    #[derive(Debug, Serialize)]
    struct AddArgs {
        author: String,
        done: bool,
    }

    /// Sample response struct used to exercise `decode_result`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct AddResponse {
        id: String,
    }

    #[test]
    fn encode_args_serializes_structs_to_convex_values() {
        let input = AddArgs {
            author: String::from("alice"),
            done: true,
        };

        let map = encode_args(&input).expect("args should encode");

        let author = map.get("author");
        assert!(matches!(author, Some(Value::String(value)) if value == "alice"));
        let done = map.get("done");
        assert!(matches!(done, Some(Value::Boolean(true))));
    }

    #[test]
    fn encode_args_allows_null_as_empty_object() {
        // `()` serializes to JSON null, which must map to an empty argument map.
        let map = encode_args(&()).expect("unit should encode");
        assert_eq!(map.len(), 0);
    }

    #[test]
    fn decode_result_deserializes_typed_payloads() {
        let payload = BTreeMap::from([("id".into(), Value::String("abc".into()))]);
        let raw = FunctionResult::Value(Value::Object(payload));

        let response: AddResponse = decode_result(raw).expect("decode");

        let expected = AddResponse { id: "abc".into() };
        assert_eq!(response, expected);
    }

    #[test]
    fn decode_result_surfaces_function_errors() {
        let raw = FunctionResult::ErrorMessage("boom".into());

        let failure = decode_result::<serde_json::Value>(raw).expect_err("error expected");

        assert!(matches!(failure, RuntimeError::FunctionMessage(message) if message == "boom"));
    }
}
370}