1use crate::ctx::Context;
2use crate::dbs::{Options, Transaction};
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fnc;
6use crate::iam::Action;
7use crate::sql::fmt::Fmt;
8use crate::sql::idiom::Idiom;
9use crate::sql::script::Script;
10use crate::sql::value::Value;
11use crate::sql::Permission;
12use async_recursion::async_recursion;
13use futures::future::try_join_all;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
/// Serde type name used when (de)serializing a [`Function`].
///
/// This must stay identical to the `#[serde(rename = ...)]` string on the
/// `Function` enum. NOTE(review): presumably other crate code compares
/// against this constant to recognize `Function` values; confirm at the use sites.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";
20
/// A function call expression in the SQL AST.
///
/// Every variant stores the un-evaluated argument expressions. These are
/// computed when the call is evaluated by `Function::compute`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Function")]
#[revisioned(revision = 1)]
pub enum Function {
	// A built-in function, identified by its full name (e.g. `math::max`).
	Normal(String, Vec<Value>),
	// A user-defined function. The name is stored WITHOUT the `fn::` prefix;
	// the prefix is added back when displaying or converting to an idiom.
	Custom(String, Vec<Value>),
	// An embedded script function; rendered as `function(args) {script}`.
	Script(Script, Vec<Value>),
}
30
31impl PartialOrd for Function {
32 #[inline]
33 fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
34 None
35 }
36}
37
38impl Function {
39 pub fn name(&self) -> Option<&str> {
41 match self {
42 Self::Normal(n, _) => Some(n.as_str()),
43 Self::Custom(n, _) => Some(n.as_str()),
44 _ => None,
45 }
46 }
47 pub fn args(&self) -> &[Value] {
49 match self {
50 Self::Normal(_, a) => a,
51 Self::Custom(_, a) => a,
52 _ => &[],
53 }
54 }
55 pub fn to_idiom(&self) -> Idiom {
57 match self {
58 Self::Script(_, _) => "function".to_string().into(),
59 Self::Normal(f, _) => f.to_owned().into(),
60 Self::Custom(f, _) => format!("fn::{f}").into(),
61 }
62 }
63 pub fn aggregate(&self, val: Value) -> Self {
65 match self {
66 Self::Normal(n, a) => {
67 let mut a = a.to_owned();
68 match a.len() {
69 0 => a.insert(0, val),
70 _ => {
71 a.remove(0);
72 a.insert(0, val);
73 }
74 }
75 Self::Normal(n.to_owned(), a)
76 }
77 _ => unreachable!(),
78 }
79 }
80 pub fn is_custom(&self) -> bool {
82 matches!(self, Self::Custom(_, _))
83 }
84
85 pub fn is_script(&self) -> bool {
87 matches!(self, Self::Script(_, _))
88 }
89
90 pub fn is_rolling(&self) -> bool {
92 match self {
93 Self::Normal(f, _) if f == "count" => true,
94 Self::Normal(f, _) if f == "math::max" => true,
95 Self::Normal(f, _) if f == "math::mean" => true,
96 Self::Normal(f, _) if f == "math::min" => true,
97 Self::Normal(f, _) if f == "math::sum" => true,
98 Self::Normal(f, _) if f == "time::max" => true,
99 Self::Normal(f, _) if f == "time::min" => true,
100 _ => false,
101 }
102 }
103 pub fn is_aggregate(&self) -> bool {
105 match self {
106 Self::Normal(f, _) if f == "array::distinct" => true,
107 Self::Normal(f, _) if f == "array::first" => true,
108 Self::Normal(f, _) if f == "array::flatten" => true,
109 Self::Normal(f, _) if f == "array::group" => true,
110 Self::Normal(f, _) if f == "array::last" => true,
111 Self::Normal(f, _) if f == "count" => true,
112 Self::Normal(f, _) if f == "math::bottom" => true,
113 Self::Normal(f, _) if f == "math::interquartile" => true,
114 Self::Normal(f, _) if f == "math::max" => true,
115 Self::Normal(f, _) if f == "math::mean" => true,
116 Self::Normal(f, _) if f == "math::median" => true,
117 Self::Normal(f, _) if f == "math::midhinge" => true,
118 Self::Normal(f, _) if f == "math::min" => true,
119 Self::Normal(f, _) if f == "math::mode" => true,
120 Self::Normal(f, _) if f == "math::nearestrank" => true,
121 Self::Normal(f, _) if f == "math::percentile" => true,
122 Self::Normal(f, _) if f == "math::sample" => true,
123 Self::Normal(f, _) if f == "math::spread" => true,
124 Self::Normal(f, _) if f == "math::stddev" => true,
125 Self::Normal(f, _) if f == "math::sum" => true,
126 Self::Normal(f, _) if f == "math::top" => true,
127 Self::Normal(f, _) if f == "math::trimean" => true,
128 Self::Normal(f, _) if f == "math::variance" => true,
129 Self::Normal(f, _) if f == "time::max" => true,
130 Self::Normal(f, _) if f == "time::min" => true,
131 _ => false,
132 }
133 }
134}
135
impl Function {
	/// Evaluates this function call and returns the resulting value.
	///
	/// - `Normal`: checks the function is allowed in `ctx`, evaluates the
	///   arguments and dispatches to `fnc::run`.
	/// - `Custom`: checks the `fn::`-prefixed name is allowed, loads the
	///   definition through the transaction, enforces its permissions, checks
	///   the argument count, then runs the body. Each coerced argument is
	///   bound by name in a child context.
	/// - `Script`: runs the embedded script when the `scripting` feature is
	///   compiled in; otherwise returns `Error::InvalidScript`.
	///
	/// # Errors
	///
	/// Returns an error if the function is not allowed, if the caller lacks
	/// permission (`Error::FunctionPermissions`), if the argument count is
	/// wrong (`Error::InvalidArguments`), or if argument evaluation,
	/// coercion or execution fails.
	#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
	#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
	pub(crate) async fn compute(
		&self,
		ctx: &Context<'_>,
		opt: &Options,
		txn: &Transaction,
		doc: Option<&'async_recursion CursorDoc<'_>>,
	) -> Result<Value, Error> {
		// Derive options with the futures flag enabled for this evaluation.
		// NOTE(review): presumably this makes future values evaluate while
		// computing the call; confirm in `Options::new_with_futures`.
		let opt = &opt.new_with_futures(true);
		// Process the function type
		match self {
			Self::Normal(s, x) => {
				// Check this built-in function is allowed
				ctx.check_allowed_function(s)?;
				// Evaluate all arguments, failing on the first error
				let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
				// Run the built-in function
				fnc::run(ctx, opt, txn, doc, s, a).await
			}
			Self::Custom(s, x) => {
				// Allow-lists refer to custom functions by their `fn::` name
				ctx.check_allowed_function(format!("fn::{s}").as_str())?;
				// Fetch the definition for the current namespace/database. The
				// transaction guard lives only inside this block, so it is
				// dropped before the function body runs.
				let val = {
					let mut run = txn.lock().await;
					run.get_and_cache_db_function(opt.ns(), opt.db(), s).await?
				};
				// Enforce the function's permissions when checks apply
				if opt.check_perms(Action::View) {
					match &val.permissions {
						Permission::Full => (),
						Permission::None => {
							return Err(Error::FunctionPermissions {
								name: s.to_owned(),
							})
						}
						Permission::Specific(e) => {
							// Evaluate the permission clause with permission
							// checks disabled. NOTE(review): presumably this
							// avoids recursive checks inside the clause itself.
							let opt = &opt.new_with_perms(false);
							if !e.compute(ctx, opt, txn, doc).await?.is_truthy() {
								return Err(Error::FunctionPermissions {
									name: s.to_owned(),
								});
							}
						}
					}
				}
				// Reject calls whose argument count differs from the definition
				if x.len() != val.args.len() {
					return Err(Error::InvalidArguments {
						name: format!("fn::{}", val.name),
						message: match val.args.len() {
							1 => String::from("The function expects 1 argument."),
							l => format!("The function expects {l} arguments."),
						},
					});
				}
				// Evaluate all arguments, failing on the first error
				let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
				// Create a child context for the function body
				let mut ctx = Context::new(ctx);
				// Bind each argument under its declared name, coerced to its
				// declared kind (coercion failure is returned as an error)
				for (val, (name, kind)) in a.into_iter().zip(&val.args) {
					ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
				}
				// Run the function body in the child context
				val.block.compute(&ctx, opt, txn, doc).await
			}
			// `s` and `x` are unused when the `scripting` feature is disabled
			#[allow(unused_variables)]
			Self::Script(s, x) => {
				#[cfg(feature = "scripting")]
				{
					// Check embedded scripting is allowed
					ctx.check_allowed_scripting()?;
					// Evaluate all arguments, failing on the first error
					let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
					// Run the embedded script
					fnc::script::run(ctx, opt, txn, doc, s, a).await
				}
				#[cfg(not(feature = "scripting"))]
				{
					Err(Error::InvalidScript {
						message: String::from("Embedded functions are not enabled."),
					})
				}
			}
		}
	}
}
233
234impl fmt::Display for Function {
235 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
236 match self {
237 Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
238 Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
239 Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
240 }
241 }
242}