1use crate::agent;
7use indexmap::IndexMap;
8use serde::{Deserialize, Serialize};
9use starlark::values::dict::DictRef as StarlarkDictRef;
10use starlark::values::float::UnpackFloat;
11use starlark::values::list::ListRef as StarlarkListRef;
12use starlark::values::{Heap as StarlarkHeap, UnpackValue, Value as StarlarkValue};
13use schemars::JsonSchema;
14
// A dynamically-typed, JSON-like input value whose leaves may also be `RichContentPart`s.
//
// Because of `#[serde(untagged)]`, deserialization tries the variants in declaration order
// and keeps the first one that succeeds, so the order below is load-bearing:
// - anything that deserializes as a `RichContentPart` is taken as one before `Object`;
// - an integral number that fits in `i64` (e.g. JSON `1`) becomes `Integer` before `Number`.
// Reordering the variants would change how existing inputs are parsed.
// (Plain `//` comments are used so the schemars-generated schema text is unaffected.)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.InputValue")]
pub enum InputValue {
    // A rich content part embedded directly in the value.
    #[schemars(title = "RichContentPart")]
    RichContentPart(agent::completions::message::RichContentPart),
    // String-keyed map; `IndexMap` preserves insertion order for iteration.
    #[schemars(title = "Object")]
    Object(IndexMap<String, InputValue>),
    #[schemars(title = "Array")]
    Array(Vec<InputValue>),
    #[schemars(title = "String")]
    String(String),
    #[schemars(title = "Integer")]
    Integer(i64),
    #[schemars(title = "Number")]
    Number(f64),
    #[schemars(title = "Boolean")]
    Boolean(bool),
}
45
46impl super::ToStarlarkValue for InputValue {
47 fn to_starlark_value<'v>(&self, heap: &'v StarlarkHeap) -> StarlarkValue<'v> {
48 match self {
49 InputValue::String(s) => s.to_starlark_value(heap),
50 InputValue::Integer(i) => i.to_starlark_value(heap),
51 InputValue::Number(f) => f.to_starlark_value(heap),
52 InputValue::Boolean(b) => b.to_starlark_value(heap),
53 InputValue::Object(map) => map.to_starlark_value(heap),
54 InputValue::Array(arr) => arr.to_starlark_value(heap),
55 InputValue::RichContentPart(part) => part.to_starlark_value(heap),
56 }
57 }
58}
59
60impl super::FromStarlarkValue for InputValue {
61 fn from_starlark_value(value: &StarlarkValue) -> Result<Self, super::ExpressionError> {
62 if value.is_none() {
63 return Err(super::ExpressionError::StarlarkConversionError("Input: expected value".into()));
64 }
65 if let Ok(Some(b)) = bool::unpack_value(*value) {
66 return Ok(InputValue::Boolean(b));
67 }
68 if let Ok(Some(i)) = i64::unpack_value(*value) {
69 return Ok(InputValue::Integer(i));
70 }
71 if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
72 return Ok(InputValue::Number(f));
73 }
74 if let Ok(Some(s)) = <&str as UnpackValue>::unpack_value(*value) {
75 return Ok(InputValue::String(s.to_owned()));
76 }
77 if let Some(list) = StarlarkListRef::from_value(*value) {
78 let mut items = Vec::with_capacity(list.len());
79 for v in list.iter() {
80 items.push(InputValue::from_starlark_value(&v)?);
81 }
82 return Ok(InputValue::Array(items));
83 }
84 if let Some(dict) = StarlarkDictRef::from_value(*value) {
85 let mut type_value = None;
87 for (k, v) in dict.iter() {
88 if let Ok(Some("type")) = <&str as UnpackValue>::unpack_value(k) {
89 type_value = <&str as UnpackValue>::unpack_value(v).ok().flatten();
90 break;
91 }
92 }
93 if matches!(type_value, Some("text" | "image_url" | "input_audio" | "input_video" | "video_url" | "file")) {
94 if let Ok(part) = agent::completions::message::RichContentPart::from_starlark_value(value) {
95 return Ok(InputValue::RichContentPart(part));
96 }
97 }
98 let mut map = IndexMap::with_capacity(dict.len());
99 for (k, v) in dict.iter() {
100 let key = <&str as UnpackValue>::unpack_value(k)
101 .map_err(|e| super::ExpressionError::StarlarkConversionError(e.to_string()))?
102 .ok_or_else(|| super::ExpressionError::StarlarkConversionError("Input: expected string key".into()))?
103 .to_owned();
104 map.insert(key, InputValue::from_starlark_value(&v)?);
105 }
106 return Ok(InputValue::Object(map));
107 }
108 Err(super::ExpressionError::StarlarkConversionError(format!(
109 "Input: unsupported type: {}",
110 value.get_type()
111 )))
112 }
113}
114
115impl super::FromSpecial for InputValue {
116 fn from_special(
117 special: &super::Special,
118 params: &super::Params,
119 ) -> Result<Self, super::ExpressionError> {
120 match special {
121 super::Special::Input => {
122 let input = match params {
123 super::Params::Owned(o) => &o.input,
124 super::Params::Ref(r) => r.input,
125 };
126 Ok(input.clone())
127 }
128 super::Special::InputItemsOptionalContextMerge => {
129 let input = match params {
133 super::Params::Owned(o) => &o.input,
134 super::Params::Ref(r) => r.input,
135 };
136 let InputValue::Array(arr) = input else {
137 return Err(super::ExpressionError::UnsupportedSpecial);
138 };
139 let mut merged_items = Vec::new();
140 let mut context = None;
141 for (i, elem) in arr.iter().enumerate() {
142 let InputValue::Object(obj) = elem else {
143 return Err(super::ExpressionError::UnsupportedSpecial);
144 };
145 if let Some(InputValue::Array(items)) = obj.get("items") {
146 merged_items.extend(items.iter().cloned());
147 }
148 if i == 0 {
149 context = obj.get("context").cloned();
150 }
151 }
152 let mut result = IndexMap::new();
153 result.insert("items".to_string(), InputValue::Array(merged_items));
154 if let Some(ctx) = context {
155 result.insert("context".to_string(), ctx);
156 }
157 Ok(InputValue::Object(result))
158 }
159 _ => Err(super::ExpressionError::UnsupportedSpecial),
160 }
161 }
162}
163
164impl super::FromSpecial for Vec<InputValue> {
165 fn from_special(
166 special: &super::Special,
167 params: &super::Params,
168 ) -> Result<Self, super::ExpressionError> {
169 match special {
170 super::Special::InputItemsOptionalContextSplit => {
171 let input = match params {
175 super::Params::Owned(o) => &o.input,
176 super::Params::Ref(r) => r.input,
177 };
178 let InputValue::Object(obj) = input else {
179 return Err(super::ExpressionError::UnsupportedSpecial);
180 };
181 let Some(InputValue::Array(items)) = obj.get("items") else {
182 return Err(super::ExpressionError::UnsupportedSpecial);
183 };
184 let context = obj.get("context");
185 let mut result = Vec::with_capacity(items.len());
186 for item in items {
187 let mut sub = IndexMap::new();
188 sub.insert("items".to_string(), InputValue::Array(vec![item.clone()]));
189 if let Some(ctx) = context {
190 sub.insert("context".to_string(), ctx.clone());
191 }
192 result.push(InputValue::Object(sub));
193 }
194 Ok(result)
195 }
196 _ => Err(super::ExpressionError::UnsupportedSpecial),
197 }
198 }
199}
200
201impl Eq for InputValue {}
202
203impl std::hash::Hash for InputValue {
204 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
205 std::mem::discriminant(self).hash(state);
206 match self {
207 InputValue::RichContentPart(p) => p.hash(state),
208 InputValue::Object(map) => {
209 map.len().hash(state);
210 for (k, v) in map {
211 k.hash(state);
212 v.hash(state);
213 }
214 }
215 InputValue::Array(arr) => arr.hash(state),
216 InputValue::String(s) => s.hash(state),
217 InputValue::Integer(i) => i.hash(state),
218 InputValue::Number(f) => canonical_f64_bits(*f).hash(state),
219 InputValue::Boolean(b) => b.hash(state),
220 }
221 }
222}
223
224impl<'a> arbitrary::Arbitrary<'a> for InputValue {
225 fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
226 if u.arbitrary().unwrap_or(false) {
227 if u.arbitrary()? {
229 let mut map = IndexMap::new();
230 while u.arbitrary().unwrap_or(false) {
231 map.insert(u.arbitrary::<String>()?, u.arbitrary()?);
232 }
233 Ok(InputValue::Object(map))
234 } else {
235 let mut arr = Vec::new();
236 while u.arbitrary().unwrap_or(false) {
237 arr.push(InputValue::arbitrary(u)?);
238 }
239 Ok(InputValue::Array(arr))
240 }
241 } else {
242 match u.int_in_range(0..=4)? {
244 0 => Ok(InputValue::RichContentPart(u.arbitrary()?)),
245 1 => Ok(InputValue::String(u.arbitrary()?)),
246 2 => Ok(InputValue::Integer(crate::arbitrary_util::arbitrary_i64(u)?)),
247 3 => Ok(InputValue::Number(crate::arbitrary_util::arbitrary_f64(u)?)),
248 _ => Ok(InputValue::Boolean(u.arbitrary()?)),
249 }
250 }
251 }
252}
253
/// Returns a bit pattern for `f` that is identical for values `PartialEq` treats as equal.
///
/// `0.0` and `-0.0` compare equal but differ in their sign bit, so both map to `0`. Every
/// NaN, regardless of sign or payload, maps to the single pattern `0x7FF8_0000_0000_0000`.
/// Every other value returns `f.to_bits()` unchanged.
fn canonical_f64_bits(f: f64) -> u64 {
    match f {
        nan if nan.is_nan() => 0x7FF8_0000_0000_0000,
        zero if zero == 0.0 => 0,
        other => other.to_bits(),
    }
}
270
271impl InputValue {
272 pub fn to_rich_content_parts(
277 self,
278 depth: usize,
279 ) -> impl Iterator<Item = agent::completions::message::RichContentPart> {
280 enum Iter {
281 RichContentPart(RichContentPartIter),
282 Object(Box<ObjectIter>),
283 Array(Box<ArrayIter>),
284 Primitive(Option<String>),
285 }
286 impl Iter {
287 pub fn new(input: InputValue, depth: usize) -> Self {
288 match input {
289 InputValue::RichContentPart(rich_content_part) => {
290 Iter::RichContentPart(RichContentPartIter {
291 first: true,
292 part: Some(rich_content_part),
293 last: true,
294 })
295 }
296 InputValue::Object(object) => {
297 Iter::Object(Box::new(ObjectIter {
298 object: object.into_iter(),
299 first: true,
300 child: None,
301 depth,
302 }))
303 }
304 InputValue::Array(array) => Iter::Array(Box::new(ArrayIter {
305 array: array.into_iter(),
306 first: true,
307 child: None,
308 depth,
309 })),
310 InputValue::String(string) => Iter::Primitive(Some(format!(
311 "\"{}\"",
312 json_escape::escape_str(&string),
313 ))),
314 InputValue::Integer(integer) => {
315 Iter::Primitive(Some(integer.to_string()))
316 }
317 InputValue::Number(number) => {
318 Iter::Primitive(Some(number.to_string()))
319 }
320 InputValue::Boolean(boolean) => {
321 Iter::Primitive(Some(boolean.to_string()))
322 }
323 }
324 }
325 }
326 impl Iterator for Iter {
327 type Item = agent::completions::message::RichContentPart;
328 fn next(&mut self) -> Option<Self::Item> {
329 match self {
330 Iter::RichContentPart(rich_content_part_iter) => {
331 rich_content_part_iter.next()
332 }
333 Iter::Object(object_iter) => object_iter.next(),
334 Iter::Array(array_iter) => array_iter.next(),
335 Iter::Primitive(primitive_option) => {
336 primitive_option.take().map(|text| {
337 agent::completions::message::RichContentPart::Text {
338 text,
339 }
340 })
341 }
342 }
343 }
344 }
345 struct RichContentPartIter {
346 first: bool,
347 part: Option<agent::completions::message::RichContentPart>,
348 last: bool,
349 }
350 impl Iterator for RichContentPartIter {
351 type Item = agent::completions::message::RichContentPart;
352 fn next(&mut self) -> Option<Self::Item> {
353 if self.first {
354 self.first = false;
355 Some(agent::completions::message::RichContentPart::Text {
356 text: '"'.to_string(),
357 })
358 } else if let Some(part) = self.part.take() {
359 Some(part)
360 } else if self.last {
361 self.last = false;
362 Some(agent::completions::message::RichContentPart::Text {
363 text: '"'.to_string(),
364 })
365 } else {
366 None
367 }
368 }
369 }
370 struct ObjectIter {
371 object: indexmap::map::IntoIter<String, InputValue>,
372 first: bool,
373 child: Option<Iter>,
374 depth: usize,
375 }
376 impl Iterator for ObjectIter {
377 type Item = agent::completions::message::RichContentPart;
378 fn next(&mut self) -> Option<Self::Item> {
379 if self.first {
380 self.first = false;
381 if let Some((key, input)) = self.object.next() {
382 self.child = Some(Iter::new(input, self.depth + 1));
383 Some(
384 agent::completions::message::RichContentPart::Text {
385 text: format!(
386 "{{\n{}\"{}\": ",
387 " ".repeat(self.depth + 1),
388 key,
389 ),
390 },
391 )
392 } else {
393 Some(
394 agent::completions::message::RichContentPart::Text {
395 text: format!("{{}}"),
396 },
397 )
398 }
399 } else if let Some(child) = &mut self.child {
400 if let Some(part) = child.next() {
401 Some(part)
402 } else if let Some((key, input)) = self.object.next() {
403 self.child = Some(Iter::new(input, self.depth + 1));
404 Some(
405 agent::completions::message::RichContentPart::Text {
406 text: format!(
407 ",\n{}\"{}\": ",
408 " ".repeat(self.depth + 1),
409 key,
410 ),
411 },
412 )
413 } else {
414 self.child = None;
415 Some(
416 agent::completions::message::RichContentPart::Text {
417 text: format!(
418 "\n{}}}",
419 " ".repeat(self.depth)
420 ),
421 },
422 )
423 }
424 } else {
425 None
426 }
427 }
428 }
429 struct ArrayIter {
430 array: std::vec::IntoIter<InputValue>,
431 first: bool,
432 child: Option<Iter>,
433 depth: usize,
434 }
435 impl Iterator for ArrayIter {
436 type Item = agent::completions::message::RichContentPart;
437 fn next(&mut self) -> Option<Self::Item> {
438 if self.first {
439 self.first = false;
440 if let Some(input) = self.array.next() {
441 self.child = Some(Iter::new(input, self.depth + 1));
442 Some(
443 agent::completions::message::RichContentPart::Text {
444 text: format!(
445 "[\n{}",
446 " ".repeat(self.depth + 1)
447 ),
448 },
449 )
450 } else {
451 Some(
452 agent::completions::message::RichContentPart::Text {
453 text: format!("[]"),
454 },
455 )
456 }
457 } else if let Some(child) = &mut self.child {
458 if let Some(part) = child.next() {
459 Some(part)
460 } else if let Some(input) = self.array.next() {
461 self.child = Some(Iter::new(input, self.depth + 1));
462 Some(
463 agent::completions::message::RichContentPart::Text {
464 text: format!(
465 ",\n{}",
466 " ".repeat(self.depth + 1),
467 ),
468 },
469 )
470 } else {
471 self.child = None;
472 Some(
473 agent::completions::message::RichContentPart::Text {
474 text: format!(
475 "\n{}]",
476 " ".repeat(self.depth)
477 ),
478 },
479 )
480 }
481 } else {
482 None
483 }
484 }
485 }
486 Iter::new(self, depth)
487 }
488}
489
// Unevaluated counterpart of `InputValue`. It has the same JSON-like shape, but object
// fields and array elements are `WithExpression`s that `compile` resolves against `Params`.
//
// `#[serde(untagged)]` tries the variants in declaration order and keeps the first that
// succeeds, so the order is load-bearing: `RichContentPart` is tried before `Object`, and
// `Integer` before `Number`. Do not reorder.
// (Plain `//` comments are used so the schemars-generated schema text is unaffected.)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.InputValueExpression")]
pub enum InputValueExpression {
    #[schemars(title = "RichContentPart")]
    RichContentPart(agent::completions::message::RichContentPart),
    // Each field is resolved to exactly one value by `compile` (via `compile_one`).
    #[schemars(title = "Object")]
    Object(IndexMap<String, super::WithExpression<InputValueExpression>>),
    // Each element may resolve to one value or to many; `compile` splices the "many" case
    // into the resulting array (via `compile_one_or_many`).
    #[schemars(title = "Array")]
    Array(Vec<super::WithExpression<InputValueExpression>>),
    #[schemars(title = "String")]
    String(String),
    #[schemars(title = "Integer")]
    Integer(i64),
    #[schemars(title = "Number")]
    Number(f64),
    #[schemars(title = "Boolean")]
    Boolean(bool),
}
520
521impl InputValueExpression {
522 pub fn compile(
524 self,
525 params: &super::Params,
526 ) -> Result<InputValue, super::ExpressionError> {
527 match self {
528 InputValueExpression::RichContentPart(rich_content_part) => {
529 Ok(InputValue::RichContentPart(rich_content_part))
530 }
531 InputValueExpression::Object(object) => {
532 let mut compiled_object = IndexMap::with_capacity(object.len());
533 for (key, value) in object {
534 compiled_object.insert(
535 key,
536 value.compile_one(params)?.compile(params)?,
537 );
538 }
539 Ok(InputValue::Object(compiled_object))
540 }
541 InputValueExpression::Array(array) => {
542 let mut compiled_array = Vec::with_capacity(array.len());
543 for item in array {
544 match item.compile_one_or_many(params)? {
545 super::OneOrMany::One(one_item) => {
546 compiled_array.push(one_item.compile(params)?);
547 }
548 super::OneOrMany::Many(many_items) => {
549 for item in many_items {
550 compiled_array.push(item.compile(params)?);
551 }
552 }
553 }
554 }
555 Ok(InputValue::Array(compiled_array))
556 }
557 InputValueExpression::String(string) => Ok(InputValue::String(string)),
558 InputValueExpression::Integer(integer) => Ok(InputValue::Integer(integer)),
559 InputValueExpression::Number(number) => Ok(InputValue::Number(number)),
560 InputValueExpression::Boolean(boolean) => Ok(InputValue::Boolean(boolean)),
561 }
562 }
563}
564
565impl super::FromStarlarkValue for InputValueExpression {
566 fn from_starlark_value(value: &StarlarkValue) -> Result<Self, super::ExpressionError> {
567 if value.is_none() {
568 return Err(super::ExpressionError::StarlarkConversionError("InputValueExpression: expected value".into()));
569 }
570 if let Ok(Some(b)) = bool::unpack_value(*value) {
571 return Ok(InputValueExpression::Boolean(b));
572 }
573 if let Ok(Some(i)) = i64::unpack_value(*value) {
574 return Ok(InputValueExpression::Integer(i));
575 }
576 if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
577 return Ok(InputValueExpression::Number(f));
578 }
579 if let Ok(Some(s)) = <&str as UnpackValue>::unpack_value(*value) {
580 return Ok(InputValueExpression::String(s.to_owned()));
581 }
582 if let Some(list) = StarlarkListRef::from_value(*value) {
583 let mut items = Vec::with_capacity(list.len());
584 for v in list.iter() {
585 items.push(super::WithExpression::Value(InputValueExpression::from_starlark_value(&v)?));
586 }
587 return Ok(InputValueExpression::Array(items));
588 }
589 if let Some(dict) = StarlarkDictRef::from_value(*value) {
590 let mut type_value = None;
592 for (k, v) in dict.iter() {
593 if let Ok(Some("type")) = <&str as UnpackValue>::unpack_value(k) {
594 type_value = <&str as UnpackValue>::unpack_value(v).ok().flatten();
595 break;
596 }
597 }
598 if matches!(type_value, Some("text" | "image_url" | "input_audio" | "input_video" | "video_url" | "file")) {
599 if let Ok(part) = agent::completions::message::RichContentPart::from_starlark_value(value) {
600 return Ok(InputValueExpression::RichContentPart(part));
601 }
602 }
603 let mut map = IndexMap::with_capacity(dict.len());
604 for (k, v) in dict.iter() {
605 let key = <&str as UnpackValue>::unpack_value(k)
606 .map_err(|e| super::ExpressionError::StarlarkConversionError(e.to_string()))?
607 .ok_or_else(|| super::ExpressionError::StarlarkConversionError("InputValueExpression: expected string key".into()))?
608 .to_owned();
609 map.insert(key, super::WithExpression::Value(InputValueExpression::from_starlark_value(&v)?));
610 }
611 return Ok(InputValueExpression::Object(map));
612 }
613 Err(super::ExpressionError::StarlarkConversionError(format!(
614 "InputValueExpression: unsupported type: {}",
615 value.get_type()
616 )))
617 }
618}
619
620impl<'a> arbitrary::Arbitrary<'a> for InputValueExpression {
621 fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
622 if u.arbitrary().unwrap_or(false) {
623 if u.arbitrary()? {
625 let mut map = IndexMap::new();
626 while u.arbitrary().unwrap_or(false) {
627 map.insert(u.arbitrary::<String>()?, u.arbitrary()?);
628 }
629 Ok(InputValueExpression::Object(map))
630 } else {
631 let mut arr = Vec::new();
632 while u.arbitrary().unwrap_or(false) {
633 arr.push(u.arbitrary()?);
634 }
635 Ok(InputValueExpression::Array(arr))
636 }
637 } else {
638 match u.int_in_range(0..=4)? {
640 0 => Ok(InputValueExpression::RichContentPart(u.arbitrary()?)),
641 1 => Ok(InputValueExpression::String(u.arbitrary()?)),
642 2 => Ok(InputValueExpression::Integer(crate::arbitrary_util::arbitrary_i64(u)?)),
643 3 => Ok(InputValueExpression::Number(crate::arbitrary_util::arbitrary_f64(u)?)),
644 _ => Ok(InputValueExpression::Boolean(u.arbitrary()?)),
645 }
646 }
647 }
648}
649
650impl super::FromSpecial for InputValueExpression {
651 fn from_special(
652 special: &super::Special,
653 params: &super::Params,
654 ) -> Result<Self, super::ExpressionError> {
655 let input = InputValue::from_special(special, params)?;
656 Ok(input_to_input_expression(input))
657 }
658}
659
660fn input_to_input_expression(input: InputValue) -> InputValueExpression {
661 match input {
662 InputValue::RichContentPart(p) => InputValueExpression::RichContentPart(p),
663 InputValue::Object(map) => InputValueExpression::Object(
664 map.into_iter()
665 .map(|(k, v)| {
666 (
667 k,
668 super::WithExpression::Value(input_to_input_expression(v)),
669 )
670 })
671 .collect(),
672 ),
673 InputValue::Array(arr) => InputValueExpression::Array(
674 arr.into_iter()
675 .map(|v| super::WithExpression::Value(input_to_input_expression(v)))
676 .collect(),
677 ),
678 InputValue::String(s) => InputValueExpression::String(s),
679 InputValue::Integer(i) => InputValueExpression::Integer(i),
680 InputValue::Number(n) => InputValueExpression::Number(n),
681 InputValue::Boolean(b) => InputValueExpression::Boolean(b),
682 }
683}