surrealdb/sql/
function.rs

1use crate::ctx::Context;
2use crate::dbs::{Options, Transaction};
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fnc;
6use crate::iam::Action;
7use crate::sql::fmt::Fmt;
8use crate::sql::idiom::Idiom;
9use crate::sql::script::Script;
10use crate::sql::value::Value;
11use crate::sql::Permission;
12use async_recursion::async_recursion;
13use futures::future::try_join_all;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
/// Private type token for [`Function`].
///
/// This string is identical to the `#[serde(rename = ...)]` name on
/// [`Function`]; the two must be kept in sync. Presumably used by a custom
/// (de)serializer to recognise this type — confirm against callers.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";
20
/// A function call expression: a built-in function, a user-defined
/// (`fn::`) function, or an embedded script function, each with its
/// argument expressions.
///
/// NOTE(review): this type is serialised via serde and `#[revisioned]`;
/// changing the derives, the serde name, or the variant order/shape likely
/// affects the stored encoding — consult the `revision` crate before editing.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Function")]
#[revisioned(revision = 1)]
pub enum Function {
	// A built-in function (dispatched through `fnc::run`): name and arguments.
	Normal(String, Vec<Value>),
	// A user-defined database function: name stored WITHOUT the `fn::`
	// prefix (which is added back when displayed), and arguments.
	Custom(String, Vec<Value>),
	// An embedded script function: the script body and its arguments.
	Script(Script, Vec<Value>),
	// Add new variants here
}
30
31impl PartialOrd for Function {
32	#[inline]
33	fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
34		None
35	}
36}
37
38impl Function {
39	/// Get function name if applicable
40	pub fn name(&self) -> Option<&str> {
41		match self {
42			Self::Normal(n, _) => Some(n.as_str()),
43			Self::Custom(n, _) => Some(n.as_str()),
44			_ => None,
45		}
46	}
47	/// Get function arguments if applicable
48	pub fn args(&self) -> &[Value] {
49		match self {
50			Self::Normal(_, a) => a,
51			Self::Custom(_, a) => a,
52			_ => &[],
53		}
54	}
55	/// Convert function call to a field name
56	pub fn to_idiom(&self) -> Idiom {
57		match self {
58			Self::Script(_, _) => "function".to_string().into(),
59			Self::Normal(f, _) => f.to_owned().into(),
60			Self::Custom(f, _) => format!("fn::{f}").into(),
61		}
62	}
63	/// Convert this function to an aggregate
64	pub fn aggregate(&self, val: Value) -> Self {
65		match self {
66			Self::Normal(n, a) => {
67				let mut a = a.to_owned();
68				match a.len() {
69					0 => a.insert(0, val),
70					_ => {
71						a.remove(0);
72						a.insert(0, val);
73					}
74				}
75				Self::Normal(n.to_owned(), a)
76			}
77			_ => unreachable!(),
78		}
79	}
80	/// Check if this function is a custom function
81	pub fn is_custom(&self) -> bool {
82		matches!(self, Self::Custom(_, _))
83	}
84
85	/// Check if this function is a scripting function
86	pub fn is_script(&self) -> bool {
87		matches!(self, Self::Script(_, _))
88	}
89
90	/// Check if this function is a rolling function
91	pub fn is_rolling(&self) -> bool {
92		match self {
93			Self::Normal(f, _) if f == "count" => true,
94			Self::Normal(f, _) if f == "math::max" => true,
95			Self::Normal(f, _) if f == "math::mean" => true,
96			Self::Normal(f, _) if f == "math::min" => true,
97			Self::Normal(f, _) if f == "math::sum" => true,
98			Self::Normal(f, _) if f == "time::max" => true,
99			Self::Normal(f, _) if f == "time::min" => true,
100			_ => false,
101		}
102	}
103	/// Check if this function is a grouping function
104	pub fn is_aggregate(&self) -> bool {
105		match self {
106			Self::Normal(f, _) if f == "array::distinct" => true,
107			Self::Normal(f, _) if f == "array::first" => true,
108			Self::Normal(f, _) if f == "array::flatten" => true,
109			Self::Normal(f, _) if f == "array::group" => true,
110			Self::Normal(f, _) if f == "array::last" => true,
111			Self::Normal(f, _) if f == "count" => true,
112			Self::Normal(f, _) if f == "math::bottom" => true,
113			Self::Normal(f, _) if f == "math::interquartile" => true,
114			Self::Normal(f, _) if f == "math::max" => true,
115			Self::Normal(f, _) if f == "math::mean" => true,
116			Self::Normal(f, _) if f == "math::median" => true,
117			Self::Normal(f, _) if f == "math::midhinge" => true,
118			Self::Normal(f, _) if f == "math::min" => true,
119			Self::Normal(f, _) if f == "math::mode" => true,
120			Self::Normal(f, _) if f == "math::nearestrank" => true,
121			Self::Normal(f, _) if f == "math::percentile" => true,
122			Self::Normal(f, _) if f == "math::sample" => true,
123			Self::Normal(f, _) if f == "math::spread" => true,
124			Self::Normal(f, _) if f == "math::stddev" => true,
125			Self::Normal(f, _) if f == "math::sum" => true,
126			Self::Normal(f, _) if f == "math::top" => true,
127			Self::Normal(f, _) if f == "math::trimean" => true,
128			Self::Normal(f, _) if f == "math::variance" => true,
129			Self::Normal(f, _) if f == "time::max" => true,
130			Self::Normal(f, _) if f == "time::min" => true,
131			_ => false,
132		}
133	}
134}
135
impl Function {
	/// Evaluate this function call and return the computed [`Value`].
	///
	/// The call is evaluated with futures enabled (`opt.new_with_futures(true)`).
	///
	/// - [`Function::Normal`]: checks the name against the context's allowed
	///   functions, evaluates the arguments, and dispatches to `fnc::run`.
	/// - [`Function::Custom`]: checks `fn::<name>` against the allowed
	///   functions, loads the function definition for the current namespace
	///   and database, enforces its `PERMISSIONS` clause when permission checks
	///   apply, verifies the argument count, binds each argument (coerced to its
	///   declared kind) under its parameter name in a new context, and evaluates
	///   the function body there.
	/// - [`Function::Script`]: when the `scripting` feature is compiled in,
	///   checks scripting is allowed, evaluates the arguments, and runs the
	///   script through `fnc::script::run`.
	///
	/// # Errors
	///
	/// - Any error from the allow-list checks, the definition lookup, argument
	///   evaluation, or the invoked function/script itself.
	/// - [`Error::FunctionPermissions`] when a custom function's permissions
	///   are `NONE`, or its permission expression is not truthy.
	/// - [`Error::InvalidArguments`] when a custom function is called with a
	///   different number of arguments than it declares.
	/// - Any error from coercing a custom function argument to its declared kind.
	/// - [`Error::InvalidScript`] for script functions when the `scripting`
	///   feature is not enabled.
	#[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
	#[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
	pub(crate) async fn compute(
		&self,
		ctx: &Context<'_>,
		opt: &Options,
		txn: &Transaction,
		doc: Option<&'async_recursion CursorDoc<'_>>,
	) -> Result<Value, Error> {
		// Enable futures for everything computed by this call
		// (presumably so `<future>` values are evaluated — confirm in `Options`)
		let opt = &opt.new_with_futures(true);
		// Process the function type
		match self {
			Self::Normal(s, x) => {
				// Check this function is allowed
				ctx.check_allowed_function(s)?;
				// Compute the function arguments, failing on the first error
				let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
				// Run the normal function
				fnc::run(ctx, opt, txn, doc, s, a).await
			}
			Self::Custom(s, x) => {
				// Check this function is allowed (custom names are checked with their `fn::` prefix)
				ctx.check_allowed_function(format!("fn::{s}").as_str())?;
				// Get the function definition. The transaction lock guard lives only
				// inside this block, so it is released before the permission clause,
				// arguments and body are computed below (each of which receives `txn`).
				let val = {
					// Claim transaction
					let mut run = txn.lock().await;
					// Get the function definition
					run.get_and_cache_db_function(opt.ns(), opt.db(), s).await?
				};
				// Check permissions, only when the options require a View check
				if opt.check_perms(Action::View) {
					match &val.permissions {
						Permission::Full => (),
						Permission::None => {
							return Err(Error::FunctionPermissions {
								name: s.to_owned(),
							})
						}
						Permission::Specific(e) => {
							// Disable permissions while evaluating the clause itself
							let opt = &opt.new_with_perms(false);
							// Process the PERMISSION clause; access requires a truthy result
							if !e.compute(ctx, opt, txn, doc).await?.is_truthy() {
								return Err(Error::FunctionPermissions {
									name: s.to_owned(),
								});
							}
						}
					}
				}
				// Check the number of supplied arguments matches the definition,
				// before any argument is evaluated
				if x.len() != val.args.len() {
					return Err(Error::InvalidArguments {
						name: format!("fn::{}", val.name),
						message: match val.args.len() {
							1 => String::from("The function expects 1 argument."),
							l => format!("The function expects {l} arguments."),
						},
					});
				}
				// Compute the function arguments, failing on the first error
				let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
				// Create a new context from the caller's context to hold the
				// parameter bindings for this call
				let mut ctx = Context::new(ctx);
				// Bind each argument, coerced to its declared kind, under its parameter
				// name. The loop variable `val` shadows the definition `val` only inside
				// the loop body; `&val.args` is evaluated before the loop starts.
				for (val, (name, kind)) in a.into_iter().zip(&val.args) {
					ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
				}
				// Run the custom function body in the new context
				val.block.compute(&ctx, opt, txn, doc).await
			}
			// `s` and `x` are only used when the `scripting` feature is enabled
			#[allow(unused_variables)]
			Self::Script(s, x) => {
				#[cfg(feature = "scripting")]
				{
					// Check if scripting is allowed
					ctx.check_allowed_scripting()?;
					// Compute the function arguments, failing on the first error
					let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
					// Run the script function
					fnc::script::run(ctx, opt, txn, doc, s, a).await
				}
				#[cfg(not(feature = "scripting"))]
				{
					Err(Error::InvalidScript {
						message: String::from("Embedded functions are not enabled."),
					})
				}
			}
		}
	}
}
233
234impl fmt::Display for Function {
235	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
236		match self {
237			Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
238			Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
239			Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
240		}
241	}
242}