systemprompt_provider_contracts/
extender.rs1use std::any::Any;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use serde_json::Value;
6
7pub struct ExtenderContext<'a> {
8 pub item: &'a Value,
9 pub all_items: &'a [Value],
10 pub config: &'a serde_yaml::Value,
11 pub web_config: &'a serde_yaml::Value,
12 pub content_html: &'a str,
13 pub url_pattern: &'a str,
14 pub source_name: &'a str,
15 db_pool: &'a (dyn Any + Send + Sync),
16}
17
18impl std::fmt::Debug for ExtenderContext<'_> {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("ExtenderContext")
21 .field("item", &self.item)
22 .field("all_items", &format!("[{} items]", self.all_items.len()))
23 .field(
24 "content_html",
25 &format!("[{} chars]", self.content_html.len()),
26 )
27 .field("url_pattern", &self.url_pattern)
28 .field("source_name", &self.source_name)
29 .field("db_pool", &"<dyn Any>")
30 .finish()
31 }
32}
33
34pub struct ExtenderContextBuilder<'a> {
35 item: &'a Value,
36 all_items: &'a [Value],
37 config: &'a serde_yaml::Value,
38 web_config: &'a serde_yaml::Value,
39 db_pool: &'a (dyn Any + Send + Sync),
40 content_html: &'a str,
41 url_pattern: &'a str,
42 source_name: &'a str,
43}
44
45impl std::fmt::Debug for ExtenderContextBuilder<'_> {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("ExtenderContextBuilder")
48 .field("item", &self.item)
49 .field("all_items", &format!("[{} items]", self.all_items.len()))
50 .field(
51 "content_html",
52 &format!("[{} chars]", self.content_html.len()),
53 )
54 .field("url_pattern", &self.url_pattern)
55 .field("source_name", &self.source_name)
56 .field("db_pool", &"<dyn Any>")
57 .finish()
58 }
59}
60
61impl<'a> ExtenderContextBuilder<'a> {
62 #[must_use]
63 pub fn new(
64 item: &'a Value,
65 all_items: &'a [Value],
66 config: &'a serde_yaml::Value,
67 web_config: &'a serde_yaml::Value,
68 db_pool: &'a (dyn Any + Send + Sync),
69 ) -> Self {
70 Self {
71 item,
72 all_items,
73 config,
74 web_config,
75 db_pool,
76 content_html: "",
77 url_pattern: "",
78 source_name: "",
79 }
80 }
81
82 #[must_use]
83 pub const fn with_content_html(mut self, content_html: &'a str) -> Self {
84 self.content_html = content_html;
85 self
86 }
87
88 #[must_use]
89 pub const fn with_url_pattern(mut self, url_pattern: &'a str) -> Self {
90 self.url_pattern = url_pattern;
91 self
92 }
93
94 #[must_use]
95 pub const fn with_source_name(mut self, source_name: &'a str) -> Self {
96 self.source_name = source_name;
97 self
98 }
99
100 #[must_use]
101 pub fn build(self) -> ExtenderContext<'a> {
102 ExtenderContext {
103 item: self.item,
104 all_items: self.all_items,
105 config: self.config,
106 web_config: self.web_config,
107 content_html: self.content_html,
108 url_pattern: self.url_pattern,
109 source_name: self.source_name,
110 db_pool: self.db_pool,
111 }
112 }
113}
114
115impl<'a> ExtenderContext<'a> {
116 #[must_use]
117 pub fn builder(
118 item: &'a Value,
119 all_items: &'a [Value],
120 config: &'a serde_yaml::Value,
121 web_config: &'a serde_yaml::Value,
122 db_pool: &'a (dyn Any + Send + Sync),
123 ) -> ExtenderContextBuilder<'a> {
124 ExtenderContextBuilder::new(item, all_items, config, web_config, db_pool)
125 }
126
127 #[must_use]
128 pub fn db_pool<T: 'static>(&self) -> Option<&T> {
129 self.db_pool.downcast_ref::<T>()
130 }
131}
132
133#[derive(Debug)]
134pub struct ExtendedData {
135 pub variables: Value,
136 pub priority: u32,
137}
138
139impl ExtendedData {
140 #[must_use]
141 pub const fn new(variables: Value) -> Self {
142 Self {
143 variables,
144 priority: 100,
145 }
146 }
147
148 #[must_use]
149 pub const fn with_priority(variables: Value, priority: u32) -> Self {
150 Self {
151 variables,
152 priority,
153 }
154 }
155}
156
157#[async_trait]
158pub trait TemplateDataExtender: Send + Sync {
159 fn extender_id(&self) -> &str;
160
161 fn applies_to(&self) -> Vec<String> {
162 vec![]
163 }
164
165 async fn extend(&self, ctx: &ExtenderContext<'_>, data: &mut Value) -> Result<()>;
166
167 fn priority(&self) -> u32 {
168 100
169 }
170}