Skip to main content

yash_env/
function.rs

1// This file is part of yash, an extended POSIX shell.
2// Copyright (C) 2021 WATANABE Yuki
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12// GNU General Public License for more details.
13//
14// You should have received a copy of the GNU General Public License
15// along with this program.  If not, see <https://www.gnu.org/licenses/>.
16
17//! Type definitions for functions.
18//!
19//! This module provides data types for defining shell functions.
20
21use crate::Env;
22use crate::source::Location;
23use std::borrow::Borrow;
24use std::collections::HashSet;
25use std::fmt::Debug;
26use std::fmt::Display;
27use std::hash::Hash;
28use std::hash::Hasher;
29use std::iter::FusedIterator;
30use std::pin::Pin;
31use std::rc::Rc;
32use thiserror::Error;
33
34/// Trait for the body of a [`Function`]
35pub trait FunctionBody<S>: Debug + Display {
36    /// Executes the function body in the given environment.
37    ///
38    /// The implementation of this method is expected to update
39    /// `env.exit_status` reflecting the result of the function execution.
40    #[allow(async_fn_in_trait)] // We don't support Send
41    async fn execute(&self, env: &mut Env<S>) -> crate::semantics::Result;
42}
43
44/// Dyn-compatible adapter for the [`FunctionBody`] trait
45///
46/// This is a dyn-compatible version of the [`FunctionBody`] trait.
47///
48/// This trait is automatically implemented for all types that implement
49/// [`FunctionBody`].
50pub trait FunctionBodyObject<S>: Debug + Display {
51    /// Executes the function body in the given environment.
52    ///
53    /// The implementation of this method is expected to update
54    /// `env.exit_status` reflecting the result of the function execution.
55    fn execute<'a>(
56        &'a self,
57        env: &'a mut Env<S>,
58    ) -> Pin<Box<dyn Future<Output = crate::semantics::Result> + 'a>>;
59}
60
61impl<S, T: FunctionBody<S> + ?Sized> FunctionBodyObject<S> for T {
62    fn execute<'a>(
63        &'a self,
64        env: &'a mut Env<S>,
65    ) -> Pin<Box<dyn Future<Output = crate::semantics::Result> + 'a>> {
66        Box::pin(self.execute(env))
67    }
68}
69
70/// Definition of a function.
71pub struct Function<S> {
72    /// String that identifies the function.
73    pub name: String,
74
75    /// Command that is executed when the function is called.
76    ///
77    /// This is wrapped in `Rc` so that we don't have to clone the entire
78    /// command when we define a function. The function definition command only
79    /// clones the `Rc` object from the abstract syntax tree to create a
80    /// `Function` object.
81    pub body: Rc<dyn FunctionBodyObject<S>>,
82
83    /// Location of the function definition command that defined this function.
84    pub origin: Location,
85
86    /// Optional location where this function was made read-only.
87    ///
88    /// If this function is not read-only, `read_only_location` is `None`.
89    /// Otherwise, `read_only_location` is the location of the simple command
90    /// that executed the `readonly` built-in that made this function read-only.
91    pub read_only_location: Option<Location>,
92}
93
94impl<S> Function<S> {
95    /// Creates a new function.
96    ///
97    /// This is a convenience function for constructing a `Function` object.
98    /// The `read_only_location` is set to `None`.
99    #[inline]
100    #[must_use]
101    pub fn new<N: Into<String>, B: Into<Rc<dyn FunctionBodyObject<S>>>>(
102        name: N,
103        body: B,
104        origin: Location,
105    ) -> Self {
106        Function {
107            name: name.into(),
108            body: body.into(),
109            origin,
110            read_only_location: None,
111        }
112    }
113
114    /// Makes the function read-only.
115    ///
116    /// This is a convenience function for doing
117    /// `self.read_only_location = Some(location)` in a method chain.
118    #[inline]
119    #[must_use]
120    pub fn make_read_only(mut self, location: Location) -> Self {
121        self.read_only_location = Some(location);
122        self
123    }
124
125    /// Whether this function is read-only or not.
126    #[must_use]
127    pub const fn is_read_only(&self) -> bool {
128        self.read_only_location.is_some()
129    }
130}
131
132// Not derived automatically because S may not implement Clone
133impl<S> Clone for Function<S> {
134    fn clone(&self) -> Self {
135        Self {
136            name: self.name.clone(),
137            body: self.body.clone(),
138            origin: self.origin.clone(),
139            read_only_location: self.read_only_location.clone(),
140        }
141    }
142}
143
144/// Compares two functions for equality.
145///
146/// Two functions are considered equal if all their members are equal.
147/// This includes comparing the `body` members by pointer equality.
148impl<S> PartialEq for Function<S> {
149    fn eq(&self, other: &Self) -> bool {
150        self.name == other.name
151            && Rc::ptr_eq(&self.body, &other.body)
152            && self.origin == other.origin
153            && self.read_only_location == other.read_only_location
154    }
155}
156
157impl<S> Eq for Function<S> {}
158
159// Not derived automatically because S may not implement Debug
160impl<S> Debug for Function<S> {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct("Function")
163            .field("name", &self.name)
164            .field("body", &self.body)
165            .field("origin", &self.origin)
166            .field("read_only_location", &self.read_only_location)
167            .finish()
168    }
169}
170
171/// Wrapper of [`Function`] for inserting into a hash set.
172///
173/// A `HashEntry` wraps a `Function` in `Rc` so that the `Function` object can
174/// outlive the execution of the function which may redefine or unset the
175/// function itself. A simple command that executes the function clones the
176/// `Rc` object from the function set and retains it until the command
177/// terminates.
178///
179/// The `Hash` and `PartialEq` implementation for `HashEntry` only compares
180/// the names of the functions.
181struct HashEntry<S>(Rc<Function<S>>);
182
183// Not derived automatically because S may not implement Clone
184impl<S> Clone for HashEntry<S> {
185    fn clone(&self) -> Self {
186        HashEntry(Rc::clone(&self.0))
187    }
188}
189
190impl<S> PartialEq for HashEntry<S> {
191    /// Compares the names of two hash entries.
192    ///
193    /// Members of [`Function`] other than `name` are not considered in this
194    /// function.
195    fn eq(&self, other: &HashEntry<S>) -> bool {
196        self.0.name == other.0.name
197    }
198}
199
200// Not derived automatically because S may not implement Eq
201impl<S> Eq for HashEntry<S> {}
202
203impl<S> Hash for HashEntry<S> {
204    /// Hashes the name of the function.
205    ///
206    /// Members of [`Function`] other than `name` are not considered in this
207    /// function.
208    fn hash<H: Hasher>(&self, state: &mut H) {
209        self.0.name.hash(state)
210    }
211}
212
213impl<S> Borrow<str> for HashEntry<S> {
214    fn borrow(&self) -> &str {
215        &self.0.name
216    }
217}
218
219// Not derived automatically because S may not implement Debug
220impl<S> Debug for HashEntry<S> {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        f.debug_tuple("HashEntry").field(&self.0).finish()
223    }
224}
225
226/// Collection of functions.
227pub struct FunctionSet<S> {
228    entries: HashSet<HashEntry<S>>,
229}
230
231// Not derived automatically because S may not implement Debug
232impl<S> Debug for FunctionSet<S> {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        f.debug_struct("FunctionSet")
235            .field("entries", &self.entries)
236            .finish()
237    }
238}
239
240// Not derived automatically because S may not implement Clone
241impl<S> Clone for FunctionSet<S> {
242    fn clone(&self) -> Self {
243        #[allow(clippy::mutable_key_type)]
244        let entries = self.entries.clone();
245        Self { entries }
246    }
247}
248
249// Not derived automatically because S may not implement Default
250impl<S> Default for FunctionSet<S> {
251    fn default() -> Self {
252        #[allow(clippy::mutable_key_type)]
253        let entries = HashSet::default();
254        Self { entries }
255    }
256}
257
258/// Error redefining a read-only function.
259#[derive(Error)]
260#[error("cannot redefine read-only function `{}`", .existing.name)]
261#[non_exhaustive]
262pub struct DefineError<S> {
263    /// Existing read-only function
264    pub existing: Rc<Function<S>>,
265    /// New function that tried to redefine the existing function
266    pub new: Rc<Function<S>>,
267}
268
269// Not derived automatically because S may not implement Clone, Debug, or PartialEq
270impl<S> Clone for DefineError<S> {
271    fn clone(&self) -> Self {
272        Self {
273            existing: self.existing.clone(),
274            new: self.new.clone(),
275        }
276    }
277}
278
279impl<S> Debug for DefineError<S> {
280    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        f.debug_struct("DefineError")
282            .field("existing", &self.existing)
283            .field("new", &self.new)
284            .finish()
285    }
286}
287
288impl<S> PartialEq for DefineError<S> {
289    fn eq(&self, other: &Self) -> bool {
290        self.existing == other.existing && self.new == other.new
291    }
292}
293
294impl<S> Eq for DefineError<S> {}
295
296/// Error unsetting a read-only function.
297#[derive(Error)]
298#[error("cannot unset read-only function `{}`", .existing.name)]
299#[non_exhaustive]
300pub struct UnsetError<S> {
301    /// Existing read-only function
302    pub existing: Rc<Function<S>>,
303}
304
305// Not derived automatically because S may not implement Clone, Debug, or PartialEq
306impl<S> Clone for UnsetError<S> {
307    fn clone(&self) -> Self {
308        Self {
309            existing: self.existing.clone(),
310        }
311    }
312}
313
314impl<S> Debug for UnsetError<S> {
315    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        f.debug_struct("UnsetError")
317            .field("existing", &self.existing)
318            .finish()
319    }
320}
321
322impl<S> PartialEq for UnsetError<S> {
323    fn eq(&self, other: &Self) -> bool {
324        self.existing == other.existing
325    }
326}
327
328impl<S> Eq for UnsetError<S> {}
329
330/// Unordered iterator over functions in a function set.
331///
332/// This iterator is created by [`FunctionSet::iter`].
333pub struct Iter<'a, S> {
334    inner: std::collections::hash_set::Iter<'a, HashEntry<S>>,
335}
336
337// Not derived automatically because S may not implement Clone or Debug
338impl<S> Clone for Iter<'_, S> {
339    fn clone(&self) -> Self {
340        Self {
341            inner: self.inner.clone(),
342        }
343    }
344}
345
346impl<S> Debug for Iter<'_, S> {
347    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348        f.debug_struct("Iter").field("inner", &self.inner).finish()
349    }
350}
351
352impl<S> FunctionSet<S> {
353    /// Creates a new empty function set.
354    #[must_use]
355    pub fn new() -> Self {
356        FunctionSet::default()
357    }
358
359    /// Returns the function with the given name.
360    #[must_use]
361    pub fn get(&self, name: &str) -> Option<&Rc<Function<S>>> {
362        self.entries.get(name).map(|entry| &entry.0)
363    }
364
365    /// Returns the number of functions in the set.
366    #[inline]
367    #[must_use]
368    pub fn len(&self) -> usize {
369        self.entries.len()
370    }
371
372    /// Returns `true` if the set contains no functions.
373    #[inline]
374    #[must_use]
375    pub fn is_empty(&self) -> bool {
376        self.entries.is_empty()
377    }
378
379    /// Inserts a function into the set.
380    ///
381    /// If a function with the same name already exists, it is replaced and
382    /// returned unless it is read-only, in which case `DefineError` is
383    /// returned.
384    pub fn define<F: Into<Rc<Function<S>>>>(
385        &mut self,
386        function: F,
387    ) -> Result<Option<Rc<Function<S>>>, DefineError<S>> {
388        #[allow(clippy::mutable_key_type)]
389        fn inner<S>(
390            entries: &mut HashSet<HashEntry<S>>,
391            new: Rc<Function<S>>,
392        ) -> Result<Option<Rc<Function<S>>>, DefineError<S>> {
393            match entries.get(new.name.as_str()) {
394                Some(existing) if existing.0.is_read_only() => Err(DefineError {
395                    existing: Rc::clone(&existing.0),
396                    new,
397                }),
398
399                _ => Ok(entries.replace(HashEntry(new)).map(|entry| entry.0)),
400            }
401        }
402        inner(&mut self.entries, function.into())
403    }
404
405    /// Removes a function from the set.
406    ///
407    /// This function returns the previously defined function if it exists.
408    /// However, if the function is read-only, `UnsetError` is returned.
409    pub fn unset(&mut self, name: &str) -> Result<Option<Rc<Function<S>>>, UnsetError<S>> {
410        match self.entries.get(name) {
411            Some(entry) if entry.0.is_read_only() => Err(UnsetError {
412                existing: Rc::clone(&entry.0),
413            }),
414
415            _ => Ok(self.entries.take(name).map(|entry| entry.0)),
416        }
417    }
418
419    /// Returns an iterator over functions in the set.
420    ///
421    /// The order of iteration is not specified.
422    pub fn iter(&self) -> Iter<'_, S> {
423        let inner = self.entries.iter();
424        Iter { inner }
425    }
426}
427
428impl<'a, S> Iterator for Iter<'a, S> {
429    type Item = &'a Rc<Function<S>>;
430
431    fn next(&mut self) -> Option<Self::Item> {
432        self.inner.next().map(|entry| &entry.0)
433    }
434}
435
436impl<S> ExactSizeIterator for Iter<'_, S> {
437    #[inline]
438    fn len(&self) -> usize {
439        self.inner.len()
440    }
441}
442
443impl<S> FusedIterator for Iter<'_, S> {}
444
445impl<'a, S> IntoIterator for &'a FunctionSet<S> {
446    type Item = &'a Rc<Function<S>>;
447    type IntoIter = Iter<'a, S>;
448    fn into_iter(self) -> Self::IntoIter {
449        self.iter()
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Function body for tests that never format or execute the body.
    #[derive(Clone, Debug)]
    struct FunctionBodyStub;

    impl std::fmt::Display for FunctionBodyStub {
        fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            unreachable!()
        }
    }
    impl<S> FunctionBody<S> for FunctionBodyStub {
        async fn execute(&self, _: &mut Env<S>) -> crate::semantics::Result {
            unreachable!()
        }
    }

    fn function_body_stub<S>() -> Rc<dyn FunctionBodyObject<S>> {
        Rc::new(FunctionBodyStub)
    }

    /// Builds a shared function that is not read-only, with a stub body and
    /// a dummy origin.
    fn stub_function(name: &'static str, origin: &'static str) -> Rc<Function<()>> {
        Rc::new(Function::new(
            name,
            function_body_stub::<()>(),
            Location::dummy(origin),
        ))
    }

    /// Same as `stub_function`, but the function is made read-only at a dummy
    /// `readonly` location.
    fn read_only_stub_function(name: &'static str, origin: &'static str) -> Rc<Function<()>> {
        let function = Function::new(name, function_body_stub::<()>(), Location::dummy(origin));
        Rc::new(function.make_read_only(Location::dummy("readonly")))
    }

    #[test]
    fn defining_new_function() {
        let function = stub_function("foo", "foo");
        let mut set = FunctionSet::<()>::new();

        let result = set.define(Rc::clone(&function));
        assert_eq!(result, Ok(None));
        assert_eq!(set.get("foo"), Some(&function));
    }

    #[test]
    fn redefining_existing_function() {
        let function1 = stub_function("foo", "foo 1");
        let function2 = stub_function("foo", "foo 2");
        let mut set = FunctionSet::<()>::new();
        set.define(Rc::clone(&function1)).unwrap();

        let result = set.define(Rc::clone(&function2));
        assert_eq!(result, Ok(Some(function1)));
        assert_eq!(set.get("foo"), Some(&function2));
    }

    #[test]
    fn redefining_readonly_function() {
        let function1 = read_only_stub_function("foo", "foo 1");
        let function2 = stub_function("foo", "foo 2");
        let mut set = FunctionSet::<()>::new();
        set.define(Rc::clone(&function1)).unwrap();

        let error = set.define(Rc::clone(&function2)).unwrap_err();
        assert_eq!(error.existing, function1);
        assert_eq!(error.new, function2);
        assert_eq!(set.get("foo"), Some(&function1));
    }

    #[test]
    fn unsetting_existing_function() {
        let function = stub_function("foo", "foo");
        let mut set = FunctionSet::<()>::new();
        set.define(Rc::clone(&function)).unwrap();

        let result = set.unset("foo").unwrap();
        assert_eq!(result, Some(function));
        assert_eq!(set.get("foo"), None);
    }

    #[test]
    fn unsetting_nonexisting_function() {
        let mut set = FunctionSet::<()>::new();

        let result = set.unset("foo").unwrap();
        assert_eq!(result, None);
        assert_eq!(set.get("foo"), None);
    }

    #[test]
    fn unsetting_readonly_function() {
        let function = read_only_stub_function("foo", "foo");
        let mut set = FunctionSet::<()>::new();
        set.define(Rc::clone(&function)).unwrap();

        let error = set.unset("foo").unwrap_err();
        assert_eq!(error.existing, function);
    }

    #[test]
    fn iteration() {
        let function1 = stub_function("foo", "foo");
        let function2 = stub_function("bar", "bar");
        let mut set = FunctionSet::<()>::new();
        set.define(Rc::clone(&function1)).unwrap();
        set.define(Rc::clone(&function2)).unwrap();

        // The iteration order is unspecified, so accept either one.
        let functions = set.iter().collect::<Vec<_>>();
        let defined_order = functions[..] == [&function1, &function2];
        let reverse_order = functions[..] == [&function2, &function1];
        assert!(defined_order || reverse_order, "{functions:?}");
    }
}