import { parallel_judge } from "std/llm/judge"
import { propose_instructions } from "std/llm/refine"
/**
 * Validate the shape of an optimize_prompt config and return it unchanged.
 *
 * Fails via `require` when config is not a dict, base_prompt is not a string,
 * eval_set is not a non-empty list, or metric is nil. Checks run in that order,
 * so the first failing rule determines the error message.
 */
fn __require_config(config) {
  require type_of(config) == "dict", "optimize_prompt: config must be a dict"
  let base_prompt = config?.base_prompt
  require type_of(base_prompt) == "string", "optimize_prompt: base_prompt must be a string"
  let eval_set = config?.eval_set
  require type_of(eval_set) == "list", "optimize_prompt: eval_set must be a list"
  require len(eval_set) > 0, "optimize_prompt: eval_set must not be empty"
  let metric = config?.metric
  // Only presence is checked here; the metric is not invoked or type-checked.
  require metric != nil, "optimize_prompt: metric closure is required"
  return config
}
/**
 * Coerce `value` (or `fallback` when value is nil) with to_int and insist the
 * result is a positive integer.
 *
 * `label` prefixes the failure message. Returns the coerced integer.
 */
fn __positive_int(value, fallback, label) {
  let parsed = to_int(value ?? fallback)
  let is_positive = parsed != nil && parsed > 0
  require is_positive, label + " must be a positive integer"
  return parsed
}
/**
 * Normalize the configured demos into a list of demo sets.
 *
 * Reads `config.demo_sets`, falling back to `config.demos`. A missing or empty
 * value yields a single empty set (`[[]]`). A list whose first element is
 * itself a list is returned as-is; any other list is wrapped as one set.
 */
fn __demo_sets(config) {
  let configured = config?.demo_sets ?? config?.demos
  if configured != nil {
    require type_of(configured) == "list", "optimize_prompt: demos must be a list"
    if len(configured) > 0 {
      // Only the first element is inspected to decide between "list of sets"
      // and "flat list of demos".
      if type_of(configured[0]) == "list" {
        return configured
      }
      return [configured]
    }
  }
  return [[]]
}
/**
 * Render a single demo as prompt text.
 *
 * Strings are returned verbatim. Dicts carrying both input/output (checked
 * first) or both question/answer are rendered as two labelled lines. Anything
 * else is serialized with json_stringify.
 */
fn __demo_text(demo) {
  let kind = type_of(demo)
  if kind == "string" {
    return demo
  }
  if kind == "dict" {
    let has_input_output = demo?.input != nil && demo?.output != nil
    if has_input_output {
      return "Input: " + to_string(demo.input) + "\nOutput: " + to_string(demo.output)
    }
    let has_question_answer = demo?.question != nil && demo?.answer != nil
    if has_question_answer {
      return "Question: " + to_string(demo.question) + "\nAnswer: " + to_string(demo.answer)
    }
  }
  // Dicts without a recognized key pair also fall through to here.
  return json_stringify(demo)
}
/**
 * Assemble the final prompt text from an instruction, the base prompt and demos.
 *
 * Sections are emitted in this order and joined by a blank line:
 *   1. the trimmed instruction (nil is treated as ""), when non-empty;
 *   2. the trimmed base prompt, when non-empty and different from the instruction;
 *   3. an "Examples:" section with every demo rendered by __demo_text, when
 *      demos is non-empty.
 */
fn __render_prompt(base_prompt, instruction, demos) {
  let lead = trim(instruction ?? "")
  let body = trim(base_prompt)
  var sections = []
  if lead != "" {
    sections = sections.push(lead)
  }
  // Skip the base prompt when it would merely repeat the instruction.
  if !(body == "" || body == lead) {
    sections = sections.push(body)
  }
  if len(demos) > 0 {
    var examples = []
    for item in demos {
      examples = examples.push(__demo_text(item))
    }
    let example_section = "Examples:\n" + join(examples, "\n\n")
    sections = sections.push(example_section)
  }
  return join(sections, "\n\n")
}
/**
 * Build the full candidate list: every proposed instruction paired with every
 * demo set.
 *
 * Instructions come from propose_instructions(config.base_prompt, config);
 * demo sets from __demo_sets(config). Each candidate is a dict with `index`
 * (its position in the returned list), `instruction`, `demos` and the rendered
 * `prompt`.
 */
fn __candidate_space(config) {
  let instructions = propose_instructions(config.base_prompt, config)
  let demo_sets = __demo_sets(config)
  var candidates = []
  for instruction in instructions {
    for demos in demo_sets {
      // The next index always equals the number of candidates built so far.
      let candidate = {
        index: len(candidates),
        instruction: instruction,
        demos: demos,
        prompt: __render_prompt(config.base_prompt, instruction, demos),
      }
      candidates = candidates.push(candidate)
    }
  }
  return candidates
}
/**
 * Resolve the exploration weight used by the acquisition function.
 *
 * Lookup order: config.acquisition.exploration, then config.exploration, then
 * the default 0.25. The chosen value is coerced with to_float and must not
 * come back nil.
 */
fn __exploration(config) {
  let configured = config?.acquisition?.exploration ?? config?.exploration ?? 0.25
  let weight = to_float(configured)
  require weight != nil, "optimize_prompt: acquisition.exploration must be numeric"
  return weight
}
/**
 * Build the options dict handed to parallel_judge by merging config.budget
 * with config.judge; a missing section counts as an empty dict.
 *
 * Each `??` fallback is parenthesized so the merge does not depend on the
 * relative precedence of `??` and `+`. Without the grouping, a parse of
 * `budget ?? ({} + judge) ?? {}` would drop the judge options whenever a
 * budget is configured.
 * NOTE(review): on duplicate keys the result follows dict `+` semantics
 * (presumably the judge value wins) — confirm.
 */
fn __judge_options(config) {
  let budget_options = (config?.budget ?? {})
  let judge_options = (config?.judge ?? {})
  return budget_options + judge_options
}
/**
 * Work out how many trials the search may run.
 *
 * Starts from config.trials (defaulting to candidate_count) and then lowers
 * the cap to, in turn:
 *   - config.budget.max_trials, when set;
 *   - floor(config.budget.max_evaluations / len(config.eval_set)), when set;
 *   - candidate_count.
 * Fails via `require` when the resulting cap is not positive.
 */
fn __trial_limit(config, candidate_count) {
  var cap = __positive_int(config?.trials, candidate_count, "optimize_prompt: trials")
  let budget_trials = config?.budget?.max_trials
  if budget_trials != nil {
    let allowed_trials = __positive_int(budget_trials, cap, "optimize_prompt: budget.max_trials")
    if allowed_trials < cap {
      cap = allowed_trials
    }
  }
  let budget_evals = config?.budget?.max_evaluations
  if budget_evals != nil {
    let case_count = len(config.eval_set)
    let max_evals = __positive_int(
      budget_evals,
      case_count,
      "optimize_prompt: budget.max_evaluations",
    )
    // Whole trials that fit in the evaluation budget; this can be zero, which
    // the final require below rejects.
    let affordable_trials = to_int(floor(max_evals / case_count))
    if affordable_trials < cap {
      cap = affordable_trials
    }
  }
  if cap > candidate_count {
    cap = candidate_count
  }
  require cap > 0, "optimize_prompt: budget leaves no eval trials"
  return cap
}
/**
 * Split text into lowercase word tokens.
 *
 * The value is stringified and lowercased, every run of characters outside
 * [A-Za-z0-9_] is replaced by one space, and the trimmed result is split on
 * spaces. Empty tokens are dropped.
 */
fn __words(text) {
  let lowered = lowercase(to_string(text))
  let spaced = regex_replace("[^A-Za-z0-9_]+", " ", lowered)
  var tokens = []
  for piece in split(trim(spaced), " ") {
    if piece != "" {
      tokens = tokens.push(piece)
    }
  }
  return tokens
}
/**
 * Word-overlap similarity between two texts.
 *
 * Counts the tokens of `a` (duplicates included) that also occur in `b` and
 * divides by the larger of the two token counts. Returns 0.0 when either text
 * has no tokens.
 */
fn __text_similarity(a, b) {
  let lhs = __words(a)
  let rhs = __words(b)
  if len(lhs) == 0 || len(rhs) == 0 {
    return 0.0
  }
  let shared = len(lhs.filter({ token -> rhs.contains(token) }))
  var longest = len(rhs)
  if len(lhs) > longest {
    longest = len(lhs)
  }
  // Multiply by 1.0 so the division is carried out on a float denominator.
  return shared / (longest * 1.0)
}
/**
 * Similarity between a candidate and a previously observed trial.
 *
 * Combines the text similarity of the two instructions with a 0.2 bonus when
 * both candidates carry the same number of demos; the result is capped at 1.0.
 */
fn __candidate_similarity(candidate, observed) {
  let text_score = __text_similarity(candidate.instruction, observed.candidate.instruction)
  let same_demo_count = len(candidate.demos) == len(observed.candidate.demos)
  var total = text_score
  if same_demo_count {
    total = text_score + 0.2
  }
  if total > 1.0 {
    total = 1.0
  }
  return total
}
/**
 * Score a candidate for selection given what has been observed so far.
 *
 * With no observations the prediction is 0.0, uncertainty is 1.0 and the score
 * equals `exploration`. Otherwise:
 *   - predicted_mean is the average of observed scores, weighted by
 *     0.05 + similarity to the candidate;
 *   - uncertainty is 1 / (1 + total weight);
 *   - the score is predicted_mean - best_score + exploration * uncertainty.
 * Returns a dict with kind, predicted_mean, uncertainty, expected_improvement
 * and value (the last two carry the same number).
 */
fn __acquisition(candidate, observations, best_score, exploration) {
  var predicted = 0.0
  var uncertainty = 1.0
  var improvement = exploration
  if len(observations) > 0 {
    var score_sum = 0.0
    var weight_sum = 0.0
    for seen in observations {
      // The 0.05 floor keeps every observation's weight strictly positive.
      let weight = 0.05 + __candidate_similarity(candidate, seen)
      score_sum = score_sum + weight * seen.score
      weight_sum = weight_sum + weight
    }
    predicted = score_sum / weight_sum
    uncertainty = 1.0 / (1.0 + weight_sum)
    improvement = predicted - best_score + exploration * uncertainty
  }
  return {
    kind: "expected_improvement",
    predicted_mean: predicted,
    uncertainty: uncertainty,
    expected_improvement: improvement,
    value: improvement,
  }
}
/**
 * Pick the not-yet-observed candidate with the highest acquisition value.
 *
 * Candidates whose index is in observed_indexes are skipped. On ties the
 * earliest candidate wins, because a later one must score strictly higher to
 * replace it. Returns {candidate, acquisition}; both are nil when every
 * candidate has already been observed.
 */
fn __next_candidate(candidates, observed_indexes, observations, best_score, exploration) {
  let unseen = candidates.filter({ option -> !observed_indexes.contains(option.index) })
  var winner = nil
  var winner_acquisition = nil
  for option in unseen {
    let scored = __acquisition(option, observations, best_score, exploration)
    var takes_lead = winner == nil
    if !takes_lead {
      takes_lead = scored.value > winner_acquisition.value
    }
    if takes_lead {
      winner = option
      winner_acquisition = scored
    }
  }
  return {candidate: winner, acquisition: winner_acquisition}
}
/**
 * Insert `entry` into `ranked`, a list already ordered by descending score.
 *
 * The entry goes immediately before the first element it strictly outscores,
 * or at the end when there is none, so equal scores keep their existing order.
 * Returns a new list.
 */
fn __insert_ranked(ranked, entry) {
  var ahead = []
  var behind = []
  var split_found = false
  for existing in ranked {
    if !split_found && entry.score > existing.score {
      split_found = true
    }
    // Everything from the split point onwards belongs after the new entry.
    if split_found {
      behind = behind.push(existing)
    } else {
      ahead = ahead.push(existing)
    }
  }
  var merged = ahead.push(entry)
  for existing in behind {
    merged = merged.push(existing)
  }
  return merged
}
/**
 * Produce the observations ordered by descending score.
 *
 * Each ranked entry carries trial, candidate, prompt (copied from
 * candidate.prompt), score and case_scores; the acquisition details are not
 * included. Entries are inserted one at a time with __insert_ranked.
 */
fn __rank_observations(observations) {
  var ordered = []
  for seen in observations {
    let entry = {
      trial: seen.trial,
      candidate: seen.candidate,
      prompt: seen.candidate.prompt,
      score: seen.score,
      case_scores: seen.case_scores,
    }
    ordered = __insert_ranked(ordered, entry)
  }
  return ordered
}
/**
 * Optimize a prompt over instruction/demo candidates using deterministic acquisition search.
 *
 * The config is validated, the candidate space is built, and up to
 * __trial_limit trials are run. The first trial always evaluates
 * candidates[0] as a seed; each later trial evaluates the unseen candidate
 * chosen by __next_candidate. Every trial scores one candidate by calling
 * parallel_judge with cfg.eval_set and cfg.metric.
 *
 * Returns a dict with:
 *   - best_prompt, best_score, best_candidate: the highest-scoring trial
 *     (the earliest one on ties);
 *   - trace: the observations in trial order;
 *   - ranked: the observations ordered by descending score;
 *   - candidates: the full candidate list.
 */
pub fn optimize_prompt(config) {
  let cfg = __require_config(config)
  let candidates = __candidate_space(cfg)
  require len(candidates) > 0, "optimize_prompt: no candidate prompts generated"
  let max_trials = __trial_limit(cfg, len(candidates))
  let exploration = __exploration(cfg)
  let judge_options = __judge_options(cfg)
  var seen_indexes = []
  var observations = []
  var trace = []
  var leader = nil
  var leader_score = 0.0
  var step = 0
  while step < max_trials {
    var picked = nil
    if step == 0 {
      // Seed trial: nothing has been observed yet, so start with the first
      // candidate and record a placeholder acquisition marked `seed: true`.
      picked = {
        candidate: candidates[0],
        acquisition: {
          kind: "expected_improvement",
          predicted_mean: 0.0,
          uncertainty: 1.0,
          expected_improvement: exploration,
          value: exploration,
          seed: true,
        },
      }
    } else {
      picked = __next_candidate(candidates, seen_indexes, observations, leader_score, exploration)
    }
    // A nil candidate means every candidate has already been evaluated.
    if picked.candidate == nil {
      break
    }
    let chosen = picked.candidate
    // One candidate is judged per trial, so only the first score entry is read.
    let judged = parallel_judge([chosen], cfg.eval_set, cfg.metric, judge_options)
    let result = judged.scores[0]
    let score = result.score
    let observation = {
      trial: step,
      candidate: chosen,
      score: score,
      acquisition: picked.acquisition,
      case_scores: result.case_scores,
    }
    observations = observations.push(observation)
    seen_indexes = seen_indexes.push(chosen.index)
    trace = trace.push(observation)
    // Strict comparison: an equal score never displaces the current leader.
    if leader == nil || score > leader.score {
      leader = observation
      leader_score = score
    }
    step = step + 1
  }
  let ranked = __rank_observations(observations)
  return {
    best_prompt: leader.candidate.prompt,
    best_score: leader.score,
    best_candidate: leader.candidate,
    trace: trace,
    ranked: ranked,
    candidates: candidates,
  }
}