arroyo_udf_plugin/
async_udf.rs

1use arrow::array::{Array, ArrayBuilder, ArrayData, UInt64Builder};
2use arroyo_udf_common::async_udf::{FfiAsyncUdfHandle, QueueData, ResultMutex};
3use arroyo_udf_common::{ArrowDatum, FfiArrays};
4use futures::stream::StreamExt;
5use futures::stream::{FuturesOrdered, FuturesUnordered};
6use std::future::Future;
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9use tokio::select;
10use tokio::sync::mpsc::{channel, Receiver, Sender};
11use tokio::time::error::Elapsed;
12
13pub use arroyo_udf_common::async_udf::{DrainResult, SendableFfiAsyncUdfHandle};
14pub use async_ffi;
15pub use tokio;
16
17pub enum FuturesEnum<F: Future + Send + 'static> {
18    Ordered(FuturesOrdered<F>),
19    Unordered(FuturesUnordered<F>),
20}
21
22impl<T: Send, F: Future<Output = T> + Send + 'static> FuturesEnum<F> {
23    pub fn push_back(&mut self, f: F) {
24        match self {
25            FuturesEnum::Ordered(futures) => futures.push_back(f),
26            FuturesEnum::Unordered(futures) => futures.push(f),
27        }
28    }
29
30    pub async fn next(&mut self) -> Option<T> {
31        match self {
32            FuturesEnum::Ordered(futures) => futures.next().await,
33            FuturesEnum::Unordered(futures) => futures.next().await,
34        }
35    }
36
37    pub fn len(&self) -> usize {
38        match self {
39            FuturesEnum::Ordered(futures) => futures.len(),
40            FuturesEnum::Unordered(futures) => futures.len(),
41        }
42    }
43
44    pub fn is_empty(&self) -> bool {
45        match self {
46            FuturesEnum::Ordered(futures) => futures.is_empty(),
47            FuturesEnum::Unordered(futures) => futures.is_empty(),
48        }
49    }
50
51    pub fn is_ordered(&self) -> bool {
52        match self {
53            FuturesEnum::Ordered(_) => true,
54            FuturesEnum::Unordered(_) => false,
55        }
56    }
57}
58
59pub struct AsyncUdfHandle {
60    pub tx: Sender<QueueData>,
61    pub results: ResultMutex,
62}
63
64impl AsyncUdfHandle {
65    pub fn into_ffi(self) -> *mut FfiAsyncUdfHandle {
66        Box::leak(Box::new(self)) as *mut AsyncUdfHandle as *mut FfiAsyncUdfHandle
67    }
68}
69
70pub async fn send(handle: SendableFfiAsyncUdfHandle, id: u64, arrays: FfiArrays) -> bool {
71    let args = arrays.into_vec();
72
73    unsafe {
74        let handle = handle.ptr as *mut AsyncUdfHandle;
75        (*handle).tx.send((id, args))
76    }
77    .await
78    .is_ok()
79}
80
81pub fn drain_results(handle: SendableFfiAsyncUdfHandle) -> DrainResult {
82    let handle = unsafe { &mut *(handle.ptr as *mut AsyncUdfHandle) };
83    match handle.results.lock() {
84        Ok(mut data) => {
85            if data.0.is_empty() {
86                return DrainResult::None;
87            }
88
89            let ids = data.0.finish();
90            let results = data.1.finish();
91            DrainResult::Data(FfiArrays::from_vec(vec![ids.to_data(), results.to_data()]))
92        }
93        Err(_) => DrainResult::Error,
94    }
95}
96
97pub fn stop_runtime(handle: SendableFfiAsyncUdfHandle) {
98    let handle = unsafe { Box::from_raw(&mut *(handle.ptr as *mut AsyncUdfHandle)) };
99    // no-op, but explicit here to make clear the point of this function
100    drop(handle);
101}
102
103pub type OutputT = (u64, Result<ArrowDatum, Elapsed>);
104
105pub struct AsyncUdf<
106    F: Future<Output = OutputT> + Send + 'static,
107    FnT: Fn(u64, Duration, Vec<ArrayData>) -> F + Send,
108> {
109    futures: FuturesEnum<F>,
110    rx: Receiver<QueueData>,
111    results: ResultMutex,
112    func: FnT,
113    timeout: Duration,
114    allowed_in_flight: usize,
115}
116
117impl<
118        F: Future<Output = OutputT> + Send + 'static,
119        FnT: Fn(u64, Duration, Vec<ArrayData>) -> F + Send + 'static,
120    > AsyncUdf<F, FnT>
121{
122    pub fn new(
123        ordered: bool,
124        timeout: Duration,
125        allowed_in_flight: u32,
126        builder: Box<dyn ArrayBuilder>,
127        func: FnT,
128    ) -> (Self, AsyncUdfHandle) {
129        let (tx, rx) = channel(1);
130
131        let results = Arc::new(Mutex::new((UInt64Builder::new(), builder)));
132
133        let handle = AsyncUdfHandle {
134            tx,
135            results: results.clone(),
136        };
137
138        (
139            Self {
140                futures: if ordered {
141                    FuturesEnum::Ordered(FuturesOrdered::new())
142                } else {
143                    FuturesEnum::Unordered(FuturesUnordered::new())
144                },
145                rx,
146                results,
147                func,
148                timeout,
149                allowed_in_flight: allowed_in_flight as usize,
150            },
151            handle,
152        )
153    }
154
155    pub fn start(self) {
156        std::thread::spawn(move || {
157            let runtime = tokio::runtime::Builder::new_current_thread()
158                .enable_all()
159                .build()
160                .unwrap();
161            runtime.block_on(async move {
162                self.run().await;
163            })
164        });
165    }
166
167    async fn run(mut self) {
168        loop {
169            select! {
170                item = self.rx.recv(), if self.futures.len() < self.allowed_in_flight => {
171                    let Some((id, args)) = item else {
172                        break;
173                    };
174
175                    self.futures.push_back((self.func)(id, self.timeout, args));
176                }
177                Some((id, result)) = self.futures.next() => {
178                    self.handle_future(id, result).await;
179                }
180            }
181        }
182    }
183
184    async fn handle_future(&mut self, id: u64, result: Result<ArrowDatum, Elapsed>) {
185        let mut results = self.results.lock().unwrap();
186        match result {
187            Ok(value) => {
188                results.0.append_value(id);
189                value.append_to(&mut results.1);
190            }
191            Err(_) => {
192                if self.futures.is_ordered() {
193                    panic!("Ordered Async UDF timed out, currently panic to preserve ordering");
194                }
195                panic!("Async UDF timed out");
196            }
197        }
198    }
199}