1use std::fmt as stdfmt;
2use std::{collections::BTreeMap, marker::PhantomData, pin::Pin, task};
3
4use convex::{
5 ConvexClient, FunctionResult, QuerySetSubscription, QuerySubscription, SubscriberId, Value,
6};
7use futures_core::Stream;
8use serde::{Serialize, de::DeserializeOwned};
9use thiserror::Error;
10use time::macros::format_description;
11use tracing::{Instrument, debug, trace};
12use tracing_subscriber::fmt::format::Writer;
13use tracing_subscriber::fmt::time::{FormatTime, UtcTime};
14use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields};
15use tracing_subscriber::registry::LookupSpan;
16use tracing_subscriber::{EnvFilter, fmt};
17
18pub trait FunctionSpec {
19 type Args: Serialize;
20 type Output: DeserializeOwned;
21
22 const PATH: &'static str;
23}
24
25pub trait QuerySpec: FunctionSpec {}
26pub trait MutationSpec: FunctionSpec {}
27pub trait ActionSpec: FunctionSpec {}
28
29pub struct TypedSubscription<F> {
30 inner: QuerySubscription,
31 marker: PhantomData<fn() -> F>,
32}
33
34impl<F> TypedSubscription<F> {
35 pub fn from_inner(inner: QuerySubscription) -> Self {
36 Self {
37 inner,
38 marker: PhantomData,
39 }
40 }
41
42 pub fn id(&self) -> &SubscriberId {
43 self.inner.id()
44 }
45
46 pub fn inner(&self) -> &QuerySubscription {
47 &self.inner
48 }
49
50 pub fn inner_mut(&mut self) -> &mut QuerySubscription {
51 &mut self.inner
52 }
53
54 pub fn into_inner(self) -> QuerySubscription {
55 self.inner
56 }
57}
58
59impl<F> std::fmt::Debug for TypedSubscription<F> {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("TypedSubscription")
62 .field("subscriber_id", self.id())
63 .finish()
64 }
65}
66
67impl<F> Stream for TypedSubscription<F>
68where
69 F: QuerySpec,
70{
71 type Item = Result<F::Output, RuntimeError>;
72
73 fn poll_next(
74 mut self: Pin<&mut Self>,
75 cx: &mut task::Context<'_>,
76 ) -> task::Poll<Option<Self::Item>> {
77 match Pin::new(&mut self.inner).poll_next(cx) {
78 task::Poll::Ready(Some(result)) => task::Poll::Ready(Some(decode_result(result))),
79 task::Poll::Ready(None) => task::Poll::Ready(None),
80 task::Poll::Pending => task::Poll::Pending,
81 }
82 }
83}
84
85pub struct RustexClient {
86 inner: ConvexClient,
87}
88
89impl Clone for RustexClient {
90 fn clone(&self) -> Self {
91 Self {
92 inner: self.inner.clone(),
93 }
94 }
95}
96
97impl RustexClient {
98 #[tracing::instrument(name = "rustex_runtime.client.new", skip_all, fields(deployment_url))]
99 pub async fn new(deployment_url: &str) -> anyhow::Result<Self> {
100 debug!("connecting Convex client");
101 Ok(Self {
102 inner: ConvexClient::new(deployment_url).await?,
103 })
104 }
105
106 pub fn from_inner(inner: ConvexClient) -> Self {
107 Self { inner }
108 }
109
110 pub fn inner(&self) -> &ConvexClient {
111 &self.inner
112 }
113
114 pub fn inner_mut(&mut self) -> &mut ConvexClient {
115 &mut self.inner
116 }
117
118 pub fn into_inner(self) -> ConvexClient {
119 self.inner
120 }
121
122 pub async fn query<F>(
123 &mut self,
124 _function: F,
125 args: &F::Args,
126 ) -> Result<F::Output, RuntimeError>
127 where
128 F: QuerySpec,
129 {
130 let encoded_args = encode_args(args)?;
131 let span = tracing::info_span!("rustex_runtime.query", convex.function = F::PATH);
132 async move {
133 debug!(argument_count = encoded_args.len(), "executing typed query");
134 let result = self.inner.query(F::PATH, encoded_args).await?;
135 decode_result(result)
136 }
137 .instrument(span)
138 .await
139 }
140
141 pub async fn subscribe<F>(
142 &mut self,
143 _function: F,
144 args: &F::Args,
145 ) -> Result<TypedSubscription<F>, RuntimeError>
146 where
147 F: QuerySpec,
148 {
149 let encoded_args = encode_args(args)?;
150 let span = tracing::info_span!("rustex_runtime.subscribe", convex.function = F::PATH);
151 async move {
152 debug!(
153 argument_count = encoded_args.len(),
154 "creating typed subscription"
155 );
156 let subscription = self.inner.subscribe(F::PATH, encoded_args).await?;
157 Ok(TypedSubscription::from_inner(subscription))
158 }
159 .instrument(span)
160 .await
161 }
162
163 pub async fn mutation<F>(
164 &mut self,
165 _function: F,
166 args: &F::Args,
167 ) -> Result<F::Output, RuntimeError>
168 where
169 F: MutationSpec,
170 {
171 let encoded_args = encode_args(args)?;
172 let span = tracing::info_span!("rustex_runtime.mutation", convex.function = F::PATH);
173 async move {
174 debug!(
175 argument_count = encoded_args.len(),
176 "executing typed mutation"
177 );
178 let result = self.inner.mutation(F::PATH, encoded_args).await?;
179 decode_result(result)
180 }
181 .instrument(span)
182 .await
183 }
184
185 pub async fn action<F>(
186 &mut self,
187 _function: F,
188 args: &F::Args,
189 ) -> Result<F::Output, RuntimeError>
190 where
191 F: ActionSpec,
192 {
193 let encoded_args = encode_args(args)?;
194 let span = tracing::info_span!("rustex_runtime.action", convex.function = F::PATH);
195 async move {
196 debug!(
197 argument_count = encoded_args.len(),
198 "executing typed action"
199 );
200 let result = self.inner.action(F::PATH, encoded_args).await?;
201 decode_result(result)
202 }
203 .instrument(span)
204 .await
205 }
206
207 pub fn watch_all(&self) -> QuerySetSubscription {
208 self.inner.watch_all()
209 }
210}
211
212#[derive(Debug, Error)]
213pub enum RuntimeError {
214 #[error(transparent)]
215 Transport(#[from] anyhow::Error),
216 #[error("Convex function returned an error message: {0}")]
217 FunctionMessage(String),
218 #[error("Convex function raised an application error: {message}")]
219 ConvexError {
220 message: String,
221 data: serde_json::Value,
222 },
223 #[error("arguments must serialize to an object or null")]
224 InvalidArgsShape,
225 #[error(transparent)]
226 Serde(#[from] serde_json::Error),
227}
228
229pub fn init_default_tracing() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
230 fmt()
231 .event_format(FlatLogFormat::default())
232 .with_env_filter(EnvFilter::from_default_env())
233 .try_init()
234}
235
236#[derive(Clone, Debug)]
237struct FlatLogFormat {
238 timer: UtcTime<time::format_description::OwnedFormatItem>,
239}
240
241impl Default for FlatLogFormat {
242 fn default() -> Self {
243 Self {
244 timer: UtcTime::new(
245 format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z").into(),
246 ),
247 }
248 }
249}
250
251impl<S, N> FormatEvent<S, N> for FlatLogFormat
252where
253 S: tracing::Subscriber + for<'lookup> LookupSpan<'lookup>,
254 N: for<'writer> FormatFields<'writer> + 'static,
255{
256 fn format_event(
257 &self,
258 ctx: &FmtContext<'_, S, N>,
259 mut writer: Writer<'_>,
260 event: &tracing::Event<'_>,
261 ) -> stdfmt::Result {
262 self.timer.format_time(&mut writer)?;
263 write_level(&mut writer, event.metadata().level())?;
264 ctx.field_format().format_fields(writer.by_ref(), event)?;
265 writeln!(writer)
266 }
267}
268
269fn write_level(writer: &mut Writer<'_>, level: &tracing::Level) -> stdfmt::Result {
270 if writer.has_ansi_escapes() {
271 let color = match *level {
272 tracing::Level::ERROR => "\x1b[31m",
273 tracing::Level::WARN => "\x1b[33m",
274 tracing::Level::INFO => "\x1b[32m",
275 tracing::Level::DEBUG => "\x1b[34m",
276 tracing::Level::TRACE => "\x1b[35m",
277 };
278 write!(writer, " {}{:>5}\x1b[0m ", color, level)
279 } else {
280 write!(writer, " {:>5} ", level)
281 }
282}
283
284#[tracing::instrument(name = "rustex_runtime.encode_args", skip_all)]
285pub fn encode_args<T: Serialize>(args: &T) -> Result<BTreeMap<String, Value>, RuntimeError> {
286 let json = serde_json::to_value(args)?;
287 match json {
288 serde_json::Value::Null => Ok(BTreeMap::new()),
289 serde_json::Value::Object(map) => map
290 .into_iter()
291 .map(|(key, value)| Ok((key, Value::try_from(value)?)))
292 .collect::<Result<BTreeMap<_, _>, RuntimeError>>()
293 .inspect(|encoded| trace!(argument_count = encoded.len(), "encoded Convex args")),
294 _ => Err(RuntimeError::InvalidArgsShape),
295 }
296}
297
298#[tracing::instrument(name = "rustex_runtime.decode_result", skip_all)]
299pub fn decode_result<T: DeserializeOwned>(result: FunctionResult) -> Result<T, RuntimeError> {
300 match result {
301 FunctionResult::Value(value) => {
302 let json: serde_json::Value = value.into();
303 trace!("deserializing Convex function value");
304 Ok(serde_json::from_value(json)?)
305 }
306 FunctionResult::ErrorMessage(message) => {
307 debug!("Convex function returned an error message");
308 Err(RuntimeError::FunctionMessage(message))
309 }
310 FunctionResult::ConvexError(error) => Err(RuntimeError::ConvexError {
311 message: error.message,
312 data: error.data.into(),
313 }),
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::{RuntimeError, decode_result, encode_args};
    use convex::{FunctionResult, Value};
    use serde::{Deserialize, Serialize};
    use std::collections::BTreeMap;

    #[derive(Debug, Serialize)]
    struct AddArgs {
        author: String,
        done: bool,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct AddResponse {
        id: String,
    }

    #[test]
    fn encode_args_serializes_structs_to_convex_values() {
        let args = AddArgs {
            author: "alice".into(),
            done: true,
        };

        let encoded = encode_args(&args).expect("args should encode");
        assert!(matches!(encoded.get("author"), Some(Value::String(value)) if value == "alice"));
        assert!(matches!(encoded.get("done"), Some(Value::Boolean(true))));
    }

    #[test]
    fn encode_args_allows_null_as_empty_object() {
        let encoded = encode_args(&()).expect("unit should encode");
        assert!(encoded.is_empty());
    }

    // Scalars and sequences are not valid Convex argument objects.
    #[test]
    fn encode_args_rejects_non_object_values() {
        assert!(matches!(encode_args(&42), Err(RuntimeError::InvalidArgsShape)));
        assert!(matches!(encode_args(&"alice"), Err(RuntimeError::InvalidArgsShape)));
        assert!(matches!(
            encode_args(&vec![1, 2, 3]),
            Err(RuntimeError::InvalidArgsShape)
        ));
    }

    #[test]
    fn decode_result_deserializes_typed_payloads() {
        let mut object = BTreeMap::new();
        object.insert("id".into(), Value::String("abc".into()));

        let decoded: AddResponse =
            decode_result(FunctionResult::Value(Value::Object(object))).expect("decode");
        assert_eq!(decoded, AddResponse { id: "abc".into() });
    }

    // A value of the wrong shape is reported as a serde error, not a panic.
    #[test]
    fn decode_result_reports_shape_mismatches_as_serde_errors() {
        let result =
            decode_result::<AddResponse>(FunctionResult::Value(Value::String("abc".into())));
        assert!(matches!(result, Err(RuntimeError::Serde(_))));
    }

    #[test]
    fn decode_result_surfaces_function_errors() {
        let error = decode_result::<serde_json::Value>(FunctionResult::ErrorMessage("boom".into()))
            .expect_err("error expected");

        assert!(matches!(error, RuntimeError::FunctionMessage(message) if message == "boom"));
    }
}