1use core::ffi::c_void;
2use core::mem::MaybeUninit;
3use core::ptr::NonNull;
4
5use crate::connection::Connection;
6use crate::error::{Error, Result};
7use crate::provider::{FeatureSet, FunctionFlags, Sqlite3Api, ValueType};
8use crate::value::{Value, ValueRef};
9
10pub struct Context<'p, P: Sqlite3Api> {
12 api: &'p P,
13 ctx: NonNull<P::Context>,
14}
15
16impl<'p, P: Sqlite3Api> Context<'p, P> {
17 pub(crate) fn new(api: &'p P, ctx: NonNull<P::Context>) -> Self {
18 Self { api, ctx }
19 }
20
21 pub fn result_null(&self) {
23 unsafe { self.api.result_null(self.ctx) }
24 }
25
26 pub fn result_int64(&self, v: i64) {
28 unsafe { self.api.result_int64(self.ctx, v) }
29 }
30
31 pub fn result_double(&self, v: f64) {
33 unsafe { self.api.result_double(self.ctx, v) }
34 }
35
36 pub fn result_text(&self, v: &str) {
38 unsafe { self.api.result_text(self.ctx, v) }
39 }
40
41 pub fn result_blob(&self, v: &[u8]) {
43 unsafe { self.api.result_blob(self.ctx, v) }
44 }
45
46 pub fn result_error(&self, msg: &str) {
48 unsafe { self.api.result_error(self.ctx, msg) }
49 }
50
51 pub fn result_value(&self, value: Value) {
53 match value {
54 Value::Null => self.result_null(),
55 Value::Integer(v) => self.result_int64(v),
56 Value::Float(v) => self.result_double(v),
57 Value::Text(v) => self.result_text(&v),
58 Value::Blob(v) => self.result_blob(&v),
59 }
60 }
61}
62
63const INLINE_ARGS: usize = 8;
64
65struct ArgBuffer<'a> {
66 inline: [MaybeUninit<ValueRef<'a>>; INLINE_ARGS],
67 len: usize,
68 heap: Option<Vec<ValueRef<'a>>>,
69}
70
71impl<'a> ArgBuffer<'a> {
72 fn new(argc: usize) -> Self {
73 let inline = unsafe {
74 MaybeUninit::<[MaybeUninit<ValueRef<'a>>; INLINE_ARGS]>::uninit().assume_init()
75 };
76 let heap = if argc > INLINE_ARGS {
77 Some(Vec::with_capacity(argc))
78 } else {
79 None
80 };
81 Self {
82 inline,
83 len: 0,
84 heap,
85 }
86 }
87
88 fn push(&mut self, value: ValueRef<'a>) {
89 if let Some(heap) = &mut self.heap {
90 heap.push(value);
91 return;
92 }
93 let slot = &mut self.inline[self.len];
94 slot.write(value);
95 self.len += 1;
96 }
97
98 fn as_slice(&self) -> &[ValueRef<'a>] {
99 if let Some(heap) = &self.heap {
100 return heap.as_slice();
101 }
102 unsafe {
103 core::slice::from_raw_parts(self.inline.as_ptr() as *const ValueRef<'a>, self.len)
104 }
105 }
106}
107
108unsafe fn value_ref_from_raw<'a, P: Sqlite3Api>(api: &P, value: NonNull<P::Value>) -> ValueRef<'a> {
109 match unsafe { api.value_type(value) } {
110 ValueType::Null => ValueRef::Null,
111 ValueType::Integer => ValueRef::Integer(unsafe { api.value_int64(value) }),
112 ValueType::Float => ValueRef::Float(unsafe { api.value_double(value) }),
113 ValueType::Text => unsafe { ValueRef::from_raw_text(api.value_text(value)) },
114 ValueType::Blob => unsafe { ValueRef::from_raw_blob(api.value_blob(value)) },
115 }
116}
117
118fn args_from_raw<'a, P: Sqlite3Api>(api: &P, argc: i32, argv: *mut *mut P::Value) -> ArgBuffer<'a> {
119 let argc = if argc < 0 { 0 } else { argc as usize };
120 let mut args = ArgBuffer::new(argc);
121 if argc == 0 || argv.is_null() {
122 return args;
123 }
124 let values = unsafe { core::slice::from_raw_parts(argv, argc) };
125 for value in values {
126 if let Some(ptr) = NonNull::new(*value) {
127 let arg = unsafe { value_ref_from_raw(api, ptr) };
128 args.push(arg);
129 } else {
130 args.push(ValueRef::Null);
131 }
132 }
133 args
134}
135
136fn set_error<P: Sqlite3Api>(ctx: &Context<'_, P>, err: &Error) {
137 let msg = err.message.as_deref().unwrap_or("sqlite function error");
138 ctx.result_error(msg);
139}
140
141struct ScalarState<P: Sqlite3Api, F> {
142 api: *const P,
143 func: F,
144}
145
146extern "C" fn scalar_trampoline<P, F>(ctx: *mut P::Context, argc: i32, argv: *mut *mut P::Value)
147where
148 P: Sqlite3Api,
149 F: for<'a> FnMut(&Context<'a, P>, &[ValueRef<'a>]) -> Result<Value> + Send + 'static,
150{
151 let ctx = match NonNull::new(ctx) {
152 Some(ctx) => ctx,
153 None => return,
154 };
155 let user_data = unsafe { P::user_data(ctx) };
156 if user_data.is_null() {
157 return;
158 }
159 let state = unsafe { &mut *(user_data as *mut ScalarState<P, F>) };
160 let api = unsafe { &*state.api };
161 let context = Context { api, ctx };
162 let args = args_from_raw(api, argc, argv);
163 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
164 (state.func)(&context, args.as_slice())
165 }));
166 match out {
167 Ok(Ok(value)) => context.result_value(value),
168 Ok(Err(err)) => set_error(&context, &err),
169 Err(_) => context.result_error("panic in sqlite function"),
170 }
171}
172
173struct AggregateState<P: Sqlite3Api, T, Init, Step, Final> {
174 api: *const P,
175 init: Init,
176 step: Step,
177 final_fn: Final,
178 _marker: core::marker::PhantomData<T>,
179}
180
181type AggStateSlot<T> = *mut T;
182
183unsafe fn get_agg_slot<P: Sqlite3Api, T>(
184 api: &P,
185 ctx: NonNull<P::Context>,
186 allocate: bool,
187) -> *mut AggStateSlot<T> {
188 let bytes = if allocate {
189 core::mem::size_of::<AggStateSlot<T>>()
190 } else {
191 0
192 };
193 unsafe { api.aggregate_context(ctx, bytes) as *mut AggStateSlot<T> }
194}
195
196extern "C" fn aggregate_step_trampoline<P, T, Init, Step, Final>(
197 ctx: *mut P::Context,
198 argc: i32,
199 argv: *mut *mut P::Value,
200) where
201 P: Sqlite3Api,
202 T: Send + 'static,
203 Init: Fn() -> T + Send + 'static,
204 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
205 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
206{
207 let ctx = match NonNull::new(ctx) {
208 Some(ctx) => ctx,
209 None => return,
210 };
211 let user_data = unsafe { P::user_data(ctx) };
212 if user_data.is_null() {
213 return;
214 }
215 let state = unsafe { &mut *(user_data as *mut AggregateState<P, T, Init, Step, Final>) };
216 let api = unsafe { &*state.api };
217 let context = Context { api, ctx };
218 let slot = unsafe { get_agg_slot::<P, T>(api, ctx, true) };
219 if slot.is_null() {
220 context.result_error("sqlite aggregate no memory");
221 return;
222 }
223 if unsafe { (*slot).is_null() } {
224 let init_out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (state.init)()));
225 match init_out {
226 Ok(value) => {
227 unsafe { *slot = Box::into_raw(Box::new(value)) };
228 }
229 Err(_) => {
230 context.result_error("panic in sqlite aggregate init");
231 return;
232 }
233 }
234 }
235 let args = args_from_raw(api, argc, argv);
236 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
237 let value = unsafe { &mut **slot };
238 (state.step)(&context, value, args.as_slice())
239 }));
240 match out {
241 Ok(Ok(())) => {}
242 Ok(Err(err)) => set_error(&context, &err),
243 Err(_) => context.result_error("panic in sqlite aggregate"),
244 }
245}
246
247extern "C" fn aggregate_final_trampoline<P, T, Init, Step, Final>(ctx: *mut P::Context)
248where
249 P: Sqlite3Api,
250 T: Send + 'static,
251 Init: Fn() -> T + Send + 'static,
252 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
253 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
254{
255 let ctx = match NonNull::new(ctx) {
256 Some(ctx) => ctx,
257 None => return,
258 };
259 let user_data = unsafe { P::user_data(ctx) };
260 if user_data.is_null() {
261 return;
262 }
263 let state = unsafe { &mut *(user_data as *mut AggregateState<P, T, Init, Step, Final>) };
264 let api = unsafe { &*state.api };
265 let context = Context { api, ctx };
266 let slot = unsafe { get_agg_slot::<P, T>(api, ctx, false) };
267 if slot.is_null() {
268 context.result_null();
269 return;
270 }
271 let state_ptr = unsafe { *slot };
272 if state_ptr.is_null() {
273 context.result_null();
274 return;
275 }
276 unsafe { *slot = core::ptr::null_mut() };
277 let value = unsafe { *Box::from_raw(state_ptr) };
278 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
279 (state.final_fn)(&context, value)
280 }));
281 match out {
282 Ok(Ok(result)) => context.result_value(result),
283 Ok(Err(err)) => set_error(&context, &err),
284 Err(_) => context.result_error("panic in sqlite aggregate final"),
285 }
286}
287
288struct WindowState<P: Sqlite3Api, T, Init, Step, Inverse, ValueFn, Final> {
289 api: *const P,
290 init: Init,
291 step: Step,
292 inverse: Inverse,
293 value_fn: ValueFn,
294 final_fn: Final,
295 _marker: core::marker::PhantomData<T>,
296}
297
298extern "C" fn window_step_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
299 ctx: *mut P::Context,
300 argc: i32,
301 argv: *mut *mut P::Value,
302) where
303 P: Sqlite3Api,
304 T: Send + 'static,
305 Init: Fn() -> T + Send + 'static,
306 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
307 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
308 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
309 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
310{
311 let ctx = match NonNull::new(ctx) {
312 Some(ctx) => ctx,
313 None => return,
314 };
315 let user_data = unsafe { P::user_data(ctx) };
316 if user_data.is_null() {
317 return;
318 }
319 let state =
320 unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
321 let api = unsafe { &*state.api };
322 let context = Context { api, ctx };
323 let slot = unsafe { get_agg_slot::<P, T>(api, ctx, true) };
324 if slot.is_null() {
325 context.result_error("sqlite window no memory");
326 return;
327 }
328 if unsafe { (*slot).is_null() } {
329 let init_out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (state.init)()));
330 match init_out {
331 Ok(value) => {
332 unsafe { *slot = Box::into_raw(Box::new(value)) };
333 }
334 Err(_) => {
335 context.result_error("panic in sqlite window init");
336 return;
337 }
338 }
339 }
340 let args = args_from_raw(api, argc, argv);
341 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
342 let value = unsafe { &mut **slot };
343 (state.step)(&context, value, args.as_slice())
344 }));
345 match out {
346 Ok(Ok(())) => {}
347 Ok(Err(err)) => set_error(&context, &err),
348 Err(_) => context.result_error("panic in sqlite window step"),
349 }
350}
351
352extern "C" fn window_inverse_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
353 ctx: *mut P::Context,
354 argc: i32,
355 argv: *mut *mut P::Value,
356) where
357 P: Sqlite3Api,
358 T: Send + 'static,
359 Init: Fn() -> T + Send + 'static,
360 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
361 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
362 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
363 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
364{
365 let ctx = match NonNull::new(ctx) {
366 Some(ctx) => ctx,
367 None => return,
368 };
369 let user_data = unsafe { P::user_data(ctx) };
370 if user_data.is_null() {
371 return;
372 }
373 let state =
374 unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
375 let api = unsafe { &*state.api };
376 let context = Context { api, ctx };
377 let slot = unsafe { get_agg_slot::<P, T>(api, ctx, true) };
378 if slot.is_null() {
379 context.result_error("sqlite window no memory");
380 return;
381 }
382 if unsafe { (*slot).is_null() } {
383 let init_out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (state.init)()));
384 match init_out {
385 Ok(value) => {
386 unsafe { *slot = Box::into_raw(Box::new(value)) };
387 }
388 Err(_) => {
389 context.result_error("panic in sqlite window init");
390 return;
391 }
392 }
393 }
394 let args = args_from_raw(api, argc, argv);
395 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
396 let value = unsafe { &mut **slot };
397 (state.inverse)(&context, value, args.as_slice())
398 }));
399 match out {
400 Ok(Ok(())) => {}
401 Ok(Err(err)) => set_error(&context, &err),
402 Err(_) => context.result_error("panic in sqlite window inverse"),
403 }
404}
405
406extern "C" fn window_value_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
407 ctx: *mut P::Context,
408) where
409 P: Sqlite3Api,
410 T: Send + 'static,
411 Init: Fn() -> T + Send + 'static,
412 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
413 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
414 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
415 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
416{
417 let ctx = match NonNull::new(ctx) {
418 Some(ctx) => ctx,
419 None => return,
420 };
421 let user_data = unsafe { P::user_data(ctx) };
422 if user_data.is_null() {
423 return;
424 }
425 let state =
426 unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
427 let api = unsafe { &*state.api };
428 let context = Context { api, ctx };
429 let slot = unsafe { get_agg_slot::<P, T>(api, ctx, false) };
430 if slot.is_null() {
431 context.result_null();
432 return;
433 }
434 if unsafe { (*slot).is_null() } {
435 context.result_null();
436 return;
437 }
438 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
439 let value = unsafe { &mut **slot };
440 (state.value_fn)(&context, value)
441 }));
442 match out {
443 Ok(Ok(result)) => context.result_value(result),
444 Ok(Err(err)) => set_error(&context, &err),
445 Err(_) => context.result_error("panic in sqlite window value"),
446 }
447}
448
449extern "C" fn window_final_trampoline<P, T, Init, Step, Inverse, ValueFn, Final>(
450 ctx: *mut P::Context,
451) where
452 P: Sqlite3Api,
453 T: Send + 'static,
454 Init: Fn() -> T + Send + 'static,
455 Step: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
456 Inverse: for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
457 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
458 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
459{
460 let ctx = match NonNull::new(ctx) {
461 Some(ctx) => ctx,
462 None => return,
463 };
464 let user_data = unsafe { P::user_data(ctx) };
465 if user_data.is_null() {
466 return;
467 }
468 let state =
469 unsafe { &mut *(user_data as *mut WindowState<P, T, Init, Step, Inverse, ValueFn, Final>) };
470 let api = unsafe { &*state.api };
471 let context = Context { api, ctx };
472 let slot = unsafe { get_agg_slot::<P, T>(api, ctx, false) };
473 if slot.is_null() {
474 context.result_null();
475 return;
476 }
477 let state_ptr = unsafe { *slot };
478 if state_ptr.is_null() {
479 context.result_null();
480 return;
481 }
482 unsafe { *slot = core::ptr::null_mut() };
483 let value = unsafe { *Box::from_raw(state_ptr) };
484 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
485 (state.final_fn)(&context, value)
486 }));
487 match out {
488 Ok(Ok(result)) => context.result_value(result),
489 Ok(Err(err)) => set_error(&context, &err),
490 Err(_) => context.result_error("panic in sqlite window final"),
491 }
492}
493
/// C-ABI destructor for user data created with `Box::into_raw(Box::<T>::new(..))`.
///
/// A null pointer is ignored. Otherwise the box is reconstructed and dropped.
extern "C" fn drop_boxed<T>(ptr: *mut c_void) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: non-null `ptr` values handed to this destructor come from
    // `Box::into_raw` of a `Box<T>` and are released exactly once.
    let owned: Box<T> = unsafe { Box::from_raw(ptr.cast::<T>()) };
    drop(owned);
}
499
500impl<'p, P: Sqlite3Api> Connection<'p, P> {
501 pub fn create_scalar_function<F>(&self, name: &str, n_args: i32, func: F) -> Result<()>
503 where
504 F: for<'a> FnMut(&Context<'a, P>, &[ValueRef<'a>]) -> Result<Value> + Send + 'static,
505 {
506 if !self
507 .api
508 .feature_set()
509 .contains(FeatureSet::CREATE_FUNCTION_V2)
510 {
511 return Err(Error::feature_unavailable("create_function_v2 unsupported"));
512 }
513 let state = Box::new(ScalarState {
514 api: self.api as *const P,
515 func,
516 });
517 let user_data = Box::into_raw(state) as *mut c_void;
518 unsafe {
519 self.api.create_function_v2(
520 self.db,
521 name,
522 n_args,
523 FunctionFlags::empty(),
524 Some(scalar_trampoline::<P, F>),
525 None,
526 None,
527 user_data,
528 Some(drop_boxed::<ScalarState<P, F>>),
529 )
530 }
531 }
532
533 pub fn create_aggregate_function<T, Init, Step, Final>(
539 &self,
540 name: &str,
541 n_args: i32,
542 init: Init,
543 step: Step,
544 final_fn: Final,
545 ) -> Result<()>
546 where
547 T: Send + 'static,
548 Init: Fn() -> T + Send + 'static,
549 Step:
550 for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
551 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
552 {
553 if !self
554 .api
555 .feature_set()
556 .contains(FeatureSet::CREATE_FUNCTION_V2)
557 {
558 return Err(Error::feature_unavailable("create_function_v2 unsupported"));
559 }
560 let state = Box::new(AggregateState::<P, T, Init, Step, Final> {
561 api: self.api as *const P,
562 init,
563 step,
564 final_fn,
565 _marker: core::marker::PhantomData,
566 });
567 let user_data = Box::into_raw(state) as *mut c_void;
568 unsafe {
569 self.api.create_function_v2(
570 self.db,
571 name,
572 n_args,
573 FunctionFlags::empty(),
574 None,
575 Some(aggregate_step_trampoline::<P, T, Init, Step, Final>),
576 Some(aggregate_final_trampoline::<P, T, Init, Step, Final>),
577 user_data,
578 Some(drop_boxed::<AggregateState<P, T, Init, Step, Final>>),
579 )
580 }
581 }
582
583 #[allow(clippy::too_many_arguments)]
589 pub fn create_window_function<T, Init, Step, Inverse, ValueFn, Final>(
590 &self,
591 name: &str,
592 n_args: i32,
593 init: Init,
594 step: Step,
595 inverse: Inverse,
596 value_fn: ValueFn,
597 final_fn: Final,
598 ) -> Result<()>
599 where
600 T: Send + 'static,
601 Init: Fn() -> T + Send + 'static,
602 Step:
603 for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
604 Inverse:
605 for<'a> FnMut(&Context<'a, P>, &mut T, &[ValueRef<'a>]) -> Result<()> + Send + 'static,
606 ValueFn: for<'a> FnMut(&Context<'a, P>, &mut T) -> Result<Value> + Send + 'static,
607 Final: for<'a> FnMut(&Context<'a, P>, T) -> Result<Value> + Send + 'static,
608 {
609 if !self
610 .api
611 .feature_set()
612 .contains(FeatureSet::WINDOW_FUNCTIONS)
613 {
614 return Err(Error::feature_unavailable("window functions unsupported"));
615 }
616 let state = Box::new(WindowState::<P, T, Init, Step, Inverse, ValueFn, Final> {
617 api: self.api as *const P,
618 init,
619 step,
620 inverse,
621 value_fn,
622 final_fn,
623 _marker: core::marker::PhantomData,
624 });
625 let user_data = Box::into_raw(state) as *mut c_void;
626 unsafe {
627 self.api.create_window_function(
628 self.db,
629 name,
630 n_args,
631 FunctionFlags::empty(),
632 Some(window_step_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
633 Some(window_final_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
634 Some(window_value_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
635 Some(window_inverse_trampoline::<P, T, Init, Step, Inverse, ValueFn, Final>),
636 user_data,
637 Some(drop_boxed::<WindowState<P, T, Init, Step, Inverse, ValueFn, Final>>),
638 )
639 }
640 }
641}
642
#[cfg(test)]
mod tests {
    use super::{ArgBuffer, INLINE_ARGS};
    use crate::value::ValueRef;

    /// A small argument list stays inline and keeps insertion order.
    #[test]
    fn arg_buffer_inline() {
        let mut buf = ArgBuffer::new(2);
        buf.push(ValueRef::Integer(1));
        buf.push(ValueRef::Integer(2));
        assert_eq!(
            buf.as_slice(),
            &[ValueRef::Integer(1), ValueRef::Integer(2)]
        );
    }

    /// A zero-argument call yields an empty slice.
    #[test]
    fn arg_buffer_empty() {
        let buf = ArgBuffer::new(0);
        assert!(buf.as_slice().is_empty());
    }

    /// Exactly `INLINE_ARGS` values fit in the inline storage.
    #[test]
    fn arg_buffer_exact_inline_capacity() {
        let mut buf = ArgBuffer::new(INLINE_ARGS);
        for i in 0..INLINE_ARGS as i64 {
            buf.push(ValueRef::Integer(i));
        }
        let expected: Vec<_> = (0..INLINE_ARGS as i64).map(ValueRef::Integer).collect();
        assert_eq!(buf.as_slice(), expected.as_slice());
    }

    /// More than `INLINE_ARGS` announced arguments use the heap path.
    #[test]
    fn arg_buffer_heap() {
        let count = INLINE_ARGS as i64 + 2;
        let mut buf = ArgBuffer::new(count as usize);
        for i in 0..count {
            buf.push(ValueRef::Integer(i));
        }
        let expected: Vec<_> = (0..count).map(ValueRef::Integer).collect();
        assert_eq!(buf.as_slice(), expected.as_slice());
    }
}