Skip to main content

reifydb_routine/function/text/
replace.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns};
5use reifydb_type::{
6	util::bitvec::BitVec,
7	value::{constraint::bytes::MaxBytes, container::utf8::Utf8Container, r#type::Type},
8};
9
10use crate::routine::{Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError};
11
12pub struct TextReplace {
13	info: RoutineInfo,
14}
15
16impl Default for TextReplace {
17	fn default() -> Self {
18		Self::new()
19	}
20}
21
22impl TextReplace {
23	pub fn new() -> Self {
24		Self {
25			info: RoutineInfo::new("text::replace"),
26		}
27	}
28}
29
30impl<'a> Routine<FunctionContext<'a>> for TextReplace {
31	fn info(&self) -> &RoutineInfo {
32		&self.info
33	}
34
35	fn return_type(&self, _input_types: &[Type]) -> Type {
36		Type::Utf8
37	}
38
39	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
40		if args.len() != 3 {
41			return Err(RoutineError::FunctionArityMismatch {
42				function: ctx.fragment.clone(),
43				expected: 3,
44				actual: args.len(),
45			});
46		}
47
48		let str_col = &args[0];
49		let from_col = &args[1];
50		let to_col = &args[2];
51
52		let (str_data, str_bv) = str_col.unwrap_option();
53		let (from_data, from_bv) = from_col.unwrap_option();
54		let (to_data, to_bv) = to_col.unwrap_option();
55		let row_count = str_data.len();
56
57		match (str_data, from_data, to_data) {
58			(
59				ColumnBuffer::Utf8 {
60					container: str_container,
61					..
62				},
63				ColumnBuffer::Utf8 {
64					container: from_container,
65					..
66				},
67				ColumnBuffer::Utf8 {
68					container: to_container,
69					..
70				},
71			) => {
72				let mut result_data = Vec::with_capacity(row_count);
73
74				for i in 0..row_count {
75					if str_container.is_defined(i)
76						&& from_container.is_defined(i) && to_container.is_defined(i)
77					{
78						let s = str_container.get(i).unwrap();
79						let from = from_container.get(i).unwrap();
80						let to = to_container.get(i).unwrap();
81						result_data.push(s.replace(from, to));
82					} else {
83						result_data.push(String::new());
84					}
85				}
86
87				let result_col_data = ColumnBuffer::Utf8 {
88					container: Utf8Container::new(result_data),
89					max_bytes: MaxBytes::MAX,
90				};
91
92				let mut combined_bv: Option<BitVec> = None;
93				for bv in [str_bv, from_bv, to_bv].into_iter().flatten() {
94					combined_bv = Some(match combined_bv {
95						Some(existing) => existing.and(bv),
96						None => bv.clone(),
97					});
98				}
99
100				let final_data = match combined_bv {
101					Some(bv) => ColumnBuffer::Option {
102						inner: Box::new(result_col_data),
103						bitvec: bv,
104					},
105					None => result_col_data,
106				};
107				Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
108			}
109			(
110				ColumnBuffer::Utf8 {
111					..
112				},
113				ColumnBuffer::Utf8 {
114					..
115				},
116				other,
117			) => Err(RoutineError::FunctionInvalidArgumentType {
118				function: ctx.fragment.clone(),
119				argument_index: 2,
120				expected: vec![Type::Utf8],
121				actual: other.get_type(),
122			}),
123			(
124				ColumnBuffer::Utf8 {
125					..
126				},
127				other,
128				_,
129			) => Err(RoutineError::FunctionInvalidArgumentType {
130				function: ctx.fragment.clone(),
131				argument_index: 1,
132				expected: vec![Type::Utf8],
133				actual: other.get_type(),
134			}),
135			(other, _, _) => Err(RoutineError::FunctionInvalidArgumentType {
136				function: ctx.fragment.clone(),
137				argument_index: 0,
138				expected: vec![Type::Utf8],
139				actual: other.get_type(),
140			}),
141		}
142	}
143}
144
145impl Function for TextReplace {
146	fn kinds(&self) -> &[FunctionKind] {
147		&[FunctionKind::Scalar]
148	}
149}